@@ -455,9 +455,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
455455 StringAttr kBlock = str_attr (" block" );
456456
457457 LinearLayout comp = dstLayout.invertAndCompose (srcLayout);
458- std::optional<LinearLayout> conversion = comp.divideRight (
459- LinearLayout::identity1D (comp.getInDimSize (kWarp ), kWarp , kWarp ) *
460- LinearLayout::identity1D (comp.getInDimSize (kBlock ), kBlock , kBlock ));
458+ std::optional<LinearLayout> conversion =
459+ comp.quotient (kBlock )->quotient (kWarp );
461460 assert (conversion && " Expecting valid conversion" );
462461 // Expected conversion is:
463462 // - register=1 -> (0, 1)
@@ -516,85 +515,87 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
516515 const auto &shape = op.getType ().getShape ();
517516 auto srcTy = op.getSrc ().getType ();
518517 auto dstTy = op.getType ();
519- std::optional<LinearLayout> srcLayout =
520- toLinearLayout (shape, srcTy.getEncoding ());
521- std::optional<LinearLayout> dstLayout =
522- toLinearLayout (shape, dstTy.getEncoding ());
523- if (!srcLayout.has_value () || !dstLayout.has_value ()) {
524- return failure ();
525- }
526518
527- // There are four cases to handle.
528- //
529- // 1. Transfer between values in the same thread, in which case we simply
530- // reorder the elements of adaptor.getSrc().
531- // 2. Transfer between values in the same warp, in which case we try to
532- // move values using warp shuffles, though if the pattern is complicated
533- // enough we may fall back to using shared memory (case 3).
534- // 3. Transfer between values in the same CTA, in which case we move values
535- // through shared memory.
536- // 4. Transfer between values in different CTAs, in which case we move
537- // values through distributed shared memory.
538- //
539- // We can tell which case we're in by examining `conversion`.
540- // For example, if the block -> block mapping is an identity layout: {1, 2,
541- // 4, ...}, then there's no movement between data in different CTAs, and we
542- // know we're not in case 4.
543- if (cvtReordersRegisters (srcTy, dstTy)) { // Case 1.
544- return transferWithinThread (op, *srcLayout, *dstLayout, adaptor,
545- rewriter);
519+ auto conversion = minimalCvtLayout (srcTy, dstTy);
520+ if (!conversion.has_value ()) {
521+ return rewriter.notifyMatchFailure (
522+ op, " NYI. srcTy and/or dstTy don't implement LLs yet" );
546523 }
524+ LinearLayout srcLayout =
525+ *toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
526+ LinearLayout dstLayout =
527+ *toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
547528
548- if (cvtNeedsWarpShuffle (srcTy, dstTy)) { // Case 2.
549- return transferWithinLane (op, *srcLayout, *dstLayout, adaptor, rewriter);
550- }
529+ StringAttr kBlock = str_attr (" block" );
530+ StringAttr kWarp = str_attr (" warp" );
531+ StringAttr kLane = str_attr (" lane" );
532+ StringAttr kRegister = str_attr (" register" );
551533
552- // TODO: match transferWithinBlockOrGroup from
553- // TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
554- return transferWithinBlockGroup (op, *srcLayout, *dstLayout, adaptor,
555- rewriter);
534+ assert (to_vector (conversion->getInDimNames ()) ==
535+ to_vector (conversion->getOutDimNames ()));
536+ auto dims = conversion->getInDimNames ();
537+ if (llvm::is_contained (dims, str_attr (" block" ))) {
538+ // Case 1: Transfer between values in different CTAs.
539+ // This requires moving values through distributed shared memory.
540+ return rewriter.notifyMatchFailure (
541+ op, " NYI: Transfer between different CTAs" );
542+ } else if (llvm::is_contained (dims, str_attr (" warp" ))) {
543+ return rewriter.notifyMatchFailure (
544+ op, " NYI: Transfer between different warps" );
545+ } else if (llvm::is_contained (dims, str_attr (" lane" ))) {
546+ // Case 3: Transfer between values in the same warp (sub-group), in which
547+ // case we try sub-group shuffles or a sub-group transposition via SLM.
548+ // If the operation is a supported sub-group shuffle, perform via shuffle
549+ // operations.
550+ if (isSubGroupShuffle (srcLayout, dstLayout) &&
551+ isSupportedSubGroupShuffle (op, adaptor)) {
552+ performSubGroupShuffle (op, srcLayout, dstLayout, adaptor, rewriter);
553+ return success ();
554+ }
555+ // If the operation is a supported sub-group transposition, perform via
556+ // SLM.
557+ if (isSubGroupTranspose (srcLayout, dstLayout) &&
558+ isSupportedSubGroupTranspose (op, adaptor)) {
559+ performSubGroupTranspose (op, srcLayout, dstLayout, adaptor, rewriter);
560+ return success ();
561+ }
562+ // TODO(jlebar): Implement me.
563+ return failure ();
564+ } else if (llvm::is_contained (dims, str_attr (" register" ))) {
565+ // Case 4. Transfer between values in the same thread, in which case we
566+ // simply reorder the elements of adaptor.getSrc().
567+ return transferWithinThread (
568+ op, dstLayout.getFreeVariableMasks ()[kRegister ],
569+ dstLayout.getInDimSize (kRegister ), *conversion, adaptor, rewriter);
570+ } else {
571+ // The two layouts are equivalent. We should probably remove these in
572+ // RemoveLayoutConversion.
573+ rewriter.replaceOp (op, adaptor.getSrc ());
574+ return success ();
575+ }
556576 }
557577
558578 LogicalResult
559- transferWithinThread (ConvertLayoutOp op, const LinearLayout &srcLayout ,
560- const LinearLayout &dstLayout , OpAdaptor adaptor,
579+ transferWithinThread (ConvertLayoutOp op, int32_t regMasks, int32_t numRegs ,
580+ const LinearLayout &conversion , OpAdaptor adaptor,
561581 ConversionPatternRewriter &rewriter) const {
562582 MLIRContext *ctx = op.getContext ();
563583 auto loc = op.getLoc ();
564584 StringAttr kRegister = str_attr (" register" );
565- StringAttr kLane = str_attr (" lane" );
566- StringAttr kWarp = str_attr (" warp" );
567- StringAttr kBlock = str_attr (" block" );
568-
569- // There are three possible cases:
570- //
571- // 1. `srcLayout` has the same number of registers as `dstLayout`.
572- // 2. `srcLayout` has fewer registers than `dstLayout`.
573- // 3. `srcLayout` has more registers than `dstLayout`.
574- //
575- // In the second case `srcLayout . dstLayout^-1` is not surjective
576- // because not all destination registers are covered.
577- // Since the goal is to cover all of the destination
578- // registers, we can instead use `dstLayout . srcLayout^-1`.
579- LinearLayout conversion = dstLayout.invertAndCompose (srcLayout);
580- auto dstToSrc = conversion.divideRight (
581- LinearLayout::identity1D (conversion.getInDimSize (kLane ), kLane , kLane ) *
582- LinearLayout::identity1D (conversion.getInDimSize (kWarp ), kWarp , kWarp ) *
583- LinearLayout::identity1D (conversion.getInDimSize (kBlock ), kBlock ,
584- kBlock ));
585-
586585 assert (!cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
587- assert (ArrayRef (to_vector (dstToSrc->getInDimNames ())) ==
588- ArrayRef{kRegister });
589- assert (ArrayRef (to_vector (dstToSrc->getOutDimNames ())) ==
590- ArrayRef{kRegister });
591586
592587 auto inVals = unpackLLElements (loc, adaptor.getSrc (), rewriter);
593- SmallVector<Value> outVals;
594- outVals.resize (dstToSrc->getInDimSize (kRegister ));
595- for (int i = 0 ; i < dstToSrc->getInDimSize (kRegister ); i++) {
596- auto srcIdx = dstToSrc->apply ({{kRegister , i}});
597- outVals[i] = inVals[srcIdx.begin ()->second ];
588+ SmallVector<Value> outVals (numRegs);
589+ for (int i = 0 ; i < outVals.size (); i++) {
590+ // Clear the free-variable mask bits from the register index.
591+ // For example, if i = 0b00111 and regMasks = 0b00100, then idx is
592+ // 0b00011. It means that register 7 (0b111) has the same value as
593+ // register 3 (0b011).
594+ auto idx = i & (~regMasks);
595+ auto srcIdx = conversion.hasInDim (kRegister )
596+ ? conversion.apply ({{kRegister , idx}}).begin ()->second
597+ : idx;
598+ outVals[i] = inVals[srcIdx];
598599 }
599600 Value result = packLLElements (loc, getTypeConverter (), outVals, rewriter,
600601 op.getType ());
@@ -611,9 +612,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
611612 StringAttr kBlock = str_attr (" block" );
612613
613614 LinearLayout comp = dstLayout.invertAndCompose (srcLayout);
614- std::optional<LinearLayout> conversion = comp.divideRight (
615- LinearLayout::identity1D (comp.getInDimSize (kWarp ), kWarp , kWarp ) *
616- LinearLayout::identity1D (comp.getInDimSize (kBlock ), kBlock , kBlock ));
615+ std::optional<LinearLayout> conversion =
616+ comp.quotient (kBlock )->quotient (kWarp );
617617 assert (conversion && " Expecting valid conversion" );
618618 // TODO: Support more kind of shuffles.
619619 // Expected conversion is:
@@ -667,11 +667,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
667667 StringAttr kWarp = str_attr (" warp" );
668668 StringAttr kBlock = str_attr (" block" );
669669 LinearLayout comp = dstLayout.invertAndCompose (srcLayout);
670- std::optional<LinearLayout> conversion = comp.divideRight (
671- LinearLayout::identity1D (comp.getInDimSize (kWarp ), kWarp , kWarp ) *
672- LinearLayout::identity1D (comp.getInDimSize (kBlock ), kBlock , kBlock ));
673- assert (conversion && " Expecting valid layout" );
674- int32_t subGroupSize = conversion->getOutDimSize (kLane );
670+ LinearLayout conversion = *comp.quotient (kBlock )->quotient (kWarp );
671+ int32_t subGroupSize = conversion.getOutDimSize (kLane );
675672
676673 Location loc = op.getLoc ();
677674
@@ -772,28 +769,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
772769 .Default (false );
773770 }
774771
775- LogicalResult transferWithinLane (ConvertLayoutOp op,
776- const LinearLayout &srcLayout,
777- const LinearLayout &dstLayout,
778- OpAdaptor adaptor,
779- ConversionPatternRewriter &rewriter) const {
780- // If the operation is a supported sub-group shuffle, perform via shuffle
781- // operations.
782- if (isSubGroupShuffle (srcLayout, dstLayout) &&
783- isSupportedSubGroupShuffle (op, adaptor)) {
784- performSubGroupShuffle (op, srcLayout, dstLayout, adaptor, rewriter);
785- return success ();
786- }
787- // If the operation is a supported sub-group transposition, perform via SLM.
788- if (isSubGroupTranspose (srcLayout, dstLayout) &&
789- isSupportedSubGroupTranspose (op, adaptor)) {
790- performSubGroupTranspose (op, srcLayout, dstLayout, adaptor, rewriter);
791- return success ();
792- }
793- // TODO(jlebar): Implement me.
794- return failure ();
795- }
796-
797772 bool isValidTypeForSubGroupTranspose (Type type) const {
798773 return TypeSwitch<Type, bool >(type)
799774 .Case ([](IntegerType intTy) {
@@ -967,14 +942,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
967942 }
968943 return unwrapFromVectors (loc, transposedVecs, rewriter);
969944 }
970-
971- LogicalResult
972- transferWithinBlockGroup (ConvertLayoutOp op, const LinearLayout &srcLayout,
973- const LinearLayout &dstLayout, OpAdaptor adaptor,
974- ConversionPatternRewriter &rewriter) const {
975- // TODO(jlebar): Implement me.
976- return failure ();
977- }
978945};
979946
980947} // namespace
0 commit comments