@@ -699,75 +699,6 @@ bool supportMMA(Value value, int version) {
699699 (elemTy.isInteger (8 ) && version >= 2 );
700700}
701701
702- bool isBlockedToDotShortcut (RankedTensorType srcTy, RankedTensorType dstTy) {
703- auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding ());
704- auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding ());
705- if (blockedLayout == nullptr || dotOperandLayout == nullptr )
706- return false ;
707- auto parentLayout =
708- dyn_cast<BlockedEncodingAttr>(dotOperandLayout.getParent ());
709- if (parentLayout == nullptr )
710- return false ;
711- auto opShape = srcTy.getShape ();
712- auto rank = opShape.size ();
713-
714- int kDim = dotOperandLayout.getOpIdx () == 0 ? rank - 1 : rank - 2 ;
715- int nonKDim = dotOperandLayout.getOpIdx () == 0 ? rank - 2 : rank - 1 ;
716- auto ctaLayout = blockedLayout.getCTALayout ();
717-
718- // The following logic checks that a source blocked layout matches a
719- // destination dot operand layout. This means that given tensor in source
720- // layout could be converted into destination layout without any data movement
721- // between registers or threads.
722- //
723- // It is considered a match if
724- // 1) Each thread in source layout holds a whole copy of all elements along
725- // the K dimension of a tensor
726- // 2) Distribution of data along all other non-K dimensions(Batch/M/N)
727- // matches between source and destination parent layouts.
728- //
729- // First condition comes from the property of dot operand layout with Blocked
730- // parent: size per threads along K dimension equals size of the tensor along
731- // K. Second condition comes from other property: dot operand layout
732- // inherits non-K dimensions from it's parent layout.
733- //
734- // clang-format off
735- //
736- // For example, following conversion is a no op:
737- // tensor<128x32xf16, #blocked<{sizePerThread = [2, 32], threadsPerWarp = [32, 1]}>>
738- // ->
739- // tensor<128x32xf16, #dot_op<{opIdx=0, parent=#blocked<{sizePerThread = [2, 8], threadsPerWarp = [32, 1]}>>>
740- //
741- // clang-format on
742- bool ctaLayoutCompatible =
743- ctaLayout.getCTASplitNum ()[kDim ] == 1 &&
744- blockedLayout.getCTALayout () == parentLayout.getCTALayout ();
745- bool threadHoldsWholeKDim =
746- blockedLayout.getSizePerThread ()[kDim ] == opShape[kDim ];
747- bool nonKDimCompatible =
748- blockedLayout.getOrder () == parentLayout.getOrder () &&
749- blockedLayout.getSizePerThread ()[nonKDim] ==
750- parentLayout.getSizePerThread ()[nonKDim] &&
751- blockedLayout.getThreadsPerWarp ()[nonKDim] ==
752- parentLayout.getThreadsPerWarp ()[nonKDim] &&
753- blockedLayout.getWarpsPerCTA ()[nonKDim] ==
754- parentLayout.getWarpsPerCTA ()[nonKDim];
755- bool matrixDimsCompatible =
756- ctaLayoutCompatible && threadHoldsWholeKDim && nonKDimCompatible;
757- if (rank == 2 )
758- return matrixDimsCompatible;
759-
760- // additional check for batch dimension if it is present
761- assert (rank == 3 );
762- bool bDimCompatible =
763- blockedLayout.getSizePerThread ()[0 ] ==
764- parentLayout.getSizePerThread ()[0 ] &&
765- blockedLayout.getThreadsPerWarp ()[0 ] ==
766- parentLayout.getThreadsPerWarp ()[0 ] &&
767- blockedLayout.getWarpsPerCTA ()[0 ] == parentLayout.getWarpsPerCTA ()[0 ];
768- return matrixDimsCompatible && bDimCompatible;
769- }
770-
771702// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
772703bool matchMmaV3AndDotOperandLayout (RankedTensorType srcTy,
773704 RankedTensorType dstTy) {
@@ -854,13 +785,11 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
854785}
855786
856787bool cvtNeedsSharedMemory (RankedTensorType srcTy, RankedTensorType dstTy) {
857- // TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and
858- // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
859- // checks.
788+ // TODO(jlebar): Remove these special cases `isMfmaToDotShortcut` once
789+ // they're fully subsumed by the linear-layout checks.
860790 return !cvtReordersRegisters (srcTy, dstTy) &&
861791 !(cvtNeedsWarpShuffle (srcTy, dstTy) &&
862792 getWarpLayoutConvertDecomposition (srcTy, dstTy)) &&
863- !isBlockedToDotShortcut (srcTy, dstTy) &&
864793 !matchMmaV3AndDotOperandLayout (srcTy, dstTy) &&
865794 // to be removed when generalized warp shuffle conversions
866795 // are ready:
0 commit comments