
Commit 74b4714

[NFC] Kill isBlockedToDotShortcut (#6003)
LLs (linear layouts) already cover this and more! I saw the pass had not been removed yet; I remove it in triton-lang/triton#6006
1 parent 1d3e2b5 · commit 74b4714
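Background on the claim that linear layouts ("LLs") already cover this: a layout conversion is free when each thread already holds the elements the destination layout expects, and the linear-layout machinery decides that generically (see the cvtReordersRegisters check in the diff below) instead of via per-pair shortcuts like isBlockedToDotShortcut. The following is a minimal toy sketch of that idea, not Triton's actual API; LayoutMap and conversionIsFree are hypothetical names.

#include <cstdint>
#include <functional>

// Toy model, not Triton's linear-layout API: a distributed layout is viewed
// as a map from (thread id, register slot) to a linearized tensor element.
using LayoutMap = std::function<int64_t(int thread, int reg)>;

// A conversion is completely free when every thread sees the same element in
// the same register slot under both layouts. (The real check is more general:
// register permutations within a thread are also cheap, which is roughly what
// the cvtReordersRegisters check in the diff below is for.)
bool conversionIsFree(const LayoutMap &src, const LayoutMap &dst,
                      int numThreads, int regsPerThread) {
  for (int t = 0; t < numThreads; ++t)
    for (int r = 0; r < regsPerThread; ++r)
      if (src(t, r) != dst(t, r))
        return false;
  return true;
}

Under this model, the blocked-to-dot-operand case that the deleted predicate special-cased is just one instance where the two maps coincide.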

2 files changed: +2, −75 lines


include/triton/Analysis/Utility.h
Lines changed: 0 additions & 2 deletions

@@ -242,8 +242,6 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
 
 bool atomicNeedsSharedMemory(Value result);
 
-bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
-
 // Return true if the src and dst layout match.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy);

lib/Analysis/Utility.cpp
Lines changed: 2 additions & 73 deletions

@@ -699,75 +699,6 @@ bool supportMMA(Value value, int version) {
          (elemTy.isInteger(8) && version >= 2);
 }
 
-bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
-  auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
-  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
-  if (blockedLayout == nullptr || dotOperandLayout == nullptr)
-    return false;
-  auto parentLayout =
-      dyn_cast<BlockedEncodingAttr>(dotOperandLayout.getParent());
-  if (parentLayout == nullptr)
-    return false;
-  auto opShape = srcTy.getShape();
-  auto rank = opShape.size();
-
-  int kDim = dotOperandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
-  int nonKDim = dotOperandLayout.getOpIdx() == 0 ? rank - 2 : rank - 1;
-  auto ctaLayout = blockedLayout.getCTALayout();
-
-  // The following logic checks that a source blocked layout matches a
-  // destination dot operand layout. This means that given tensor in source
-  // layout could be converted into destination layout without any data movement
-  // between registers or threads.
-  //
-  // It is considered a match if
-  // 1) Each thread in source layout holds a whole copy of all elements along
-  //    the K dimension of a tensor
-  // 2) Distribution of data along all other non-K dimensions(Batch/M/N)
-  //    matches between source and destination parent layouts.
-  //
-  // First condition comes from the property of dot operand layout with Blocked
-  // parent: size per threads along K dimension equals size of the tensor along
-  // K. Second condition comes from other property: dot operand layout
-  // inherits non-K dimensions from it's parent layout.
-  //
-  // clang-format off
-  //
-  // For example, following conversion is a no op:
-  //   tensor<128x32xf16, #blocked<{sizePerThread = [2, 32], threadsPerWarp = [32, 1]}>>
-  //   ->
-  //   tensor<128x32xf16, #dot_op<{opIdx=0, parent=#blocked<{sizePerThread = [2, 8], threadsPerWarp = [32, 1]}>>>
-  //
-  // clang-format on
-  bool ctaLayoutCompatible =
-      ctaLayout.getCTASplitNum()[kDim] == 1 &&
-      blockedLayout.getCTALayout() == parentLayout.getCTALayout();
-  bool threadHoldsWholeKDim =
-      blockedLayout.getSizePerThread()[kDim] == opShape[kDim];
-  bool nonKDimCompatible =
-      blockedLayout.getOrder() == parentLayout.getOrder() &&
-      blockedLayout.getSizePerThread()[nonKDim] ==
-          parentLayout.getSizePerThread()[nonKDim] &&
-      blockedLayout.getThreadsPerWarp()[nonKDim] ==
-          parentLayout.getThreadsPerWarp()[nonKDim] &&
-      blockedLayout.getWarpsPerCTA()[nonKDim] ==
-          parentLayout.getWarpsPerCTA()[nonKDim];
-  bool matrixDimsCompatible =
-      ctaLayoutCompatible && threadHoldsWholeKDim && nonKDimCompatible;
-  if (rank == 2)
-    return matrixDimsCompatible;
-
-  // additional check for batch dimension if it is present
-  assert(rank == 3);
-  bool bDimCompatible =
-      blockedLayout.getSizePerThread()[0] ==
-          parentLayout.getSizePerThread()[0] &&
-      blockedLayout.getThreadsPerWarp()[0] ==
-          parentLayout.getThreadsPerWarp()[0] &&
-      blockedLayout.getWarpsPerCTA()[0] == parentLayout.getWarpsPerCTA()[0];
-  return matrixDimsCompatible && bDimCompatible;
-}
-
 // For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy) {

@@ -854,13 +785,11 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
 }
 
 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
-  // TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and
-  // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
-  // checks.
+  // TODO(jlebar): Remove these special cases `isMfmaToDotShortcut` once
+  // they're fully subsumed by the linear-layout checks.
   return !cvtReordersRegisters(srcTy, dstTy) &&
          !(cvtNeedsWarpShuffle(srcTy, dstTy) &&
            getWarpLayoutConvertDecomposition(srcTy, dstTy)) &&
-         !isBlockedToDotShortcut(srcTy, dstTy) &&
          !matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
          // to be removed when generalized warp shuffle conversions
          // are ready:
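For concreteness, here is a hypothetical standalone restatement of the deleted predicate's rank-2 logic, using plain structs in place of Triton's encoding attributes. It ignores CTA splitting and the batch-dimension check, and the warpsPerCTA/order values in main() are made up just to complete the example from the deleted comment; it is an illustration, not a drop-in replacement.

#include <array>
#include <cstdint>
#include <iostream>

// Hypothetical mirror of the deleted isBlockedToDotShortcut, rank-2 only.
struct BlockedParams {
  std::array<int64_t, 2> sizePerThread;
  std::array<int64_t, 2> threadsPerWarp;
  std::array<int64_t, 2> warpsPerCTA;
  std::array<int64_t, 2> order;
};

struct DotOperandParams {
  int opIdx;            // 0 = A operand, 1 = B operand
  BlockedParams parent; // dot operand inherits non-K dims from its parent
};

bool blockedToDotIsFree(const std::array<int64_t, 2> &shape,
                        const BlockedParams &src, const DotOperandParams &dst) {
  // K is the last dim for operand A and the second-to-last dim for operand B.
  int kDim = dst.opIdx == 0 ? 1 : 0;
  int nonKDim = dst.opIdx == 0 ? 0 : 1;
  // 1) Each thread of the source layout holds the whole K dimension.
  bool threadHoldsWholeKDim = src.sizePerThread[kDim] == shape[kDim];
  // 2) The non-K dimension is distributed identically in source and parent.
  bool nonKDimCompatible =
      src.order == dst.parent.order &&
      src.sizePerThread[nonKDim] == dst.parent.sizePerThread[nonKDim] &&
      src.threadsPerWarp[nonKDim] == dst.parent.threadsPerWarp[nonKDim] &&
      src.warpsPerCTA[nonKDim] == dst.parent.warpsPerCTA[nonKDim];
  return threadHoldsWholeKDim && nonKDimCompatible;
}

int main() {
  // The example from the deleted comment: a 128x32 tensor in a blocked layout
  // with sizePerThread = [2, 32], converted to the A operand of a dot whose
  // parent blocked layout has sizePerThread = [2, 8].
  std::array<int64_t, 2> shape = {128, 32};
  BlockedParams src = {{2, 32}, {32, 1}, {4, 1}, {1, 0}};
  DotOperandParams dst = {0, {{2, 8}, {32, 1}, {4, 1}, {1, 0}}};
  std::cout << (blockedToDotIsFree(shape, src, dst) ? "free"
                                                    : "needs data movement")
            << "\n";
  return 0;
}

Running this prints "free", matching the deleted comment's claim that the 128x32 conversion is a no-op.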
