
Commit b8d8ce9

[Backend] Bypass conversion for suitable blocked to dotOperand layout (#4538)
This PR extends the shared-memory bypass to blocked->dotOperand conversions and adds the bypass check to DecomposeUnsupportedConversions and ReduceDataDuplication. This commit is a preparation step towards improving CodeGen and the efficiency of skinny dot cases.
Parent: 1495116
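For context, the sketch below (not part of the diff) models in plain C++ the decision this change enables: when a thread of the source blocked layout already owns the entire K slice of the operand, the blocked->dotOperand convert_layout needs no shared-memory round trip and can simply forward registers. The BlockedLayout struct and needsSharedMemoryRoundTrip helper are simplified stand-ins, not Triton's actual types or its cvtNeedsSharedMemory.

#include <array>
#include <cstdio>

// Simplified stand-in for a 2-D blocked layout (hypothetical, illustration only).
struct BlockedLayout {
  std::array<int, 2> sizePerThread;
  std::array<int, 2> threadsPerWarp;
  std::array<int, 2> warpsPerCTA;
};

// Hypothetical helper: true when a blocked->dotOperand convert would still
// need a shared-memory round trip, i.e. a thread does not already hold the
// whole K dimension of the operand.
bool needsSharedMemoryRoundTrip(const std::array<int, 2> &shape,
                                const BlockedLayout &src, int opIdx) {
  int kDim = (opIdx == 0) ? 1 : 0; // K is the last dim for operand A, first for B
  return src.sizePerThread[kDim] != shape[kDim];
}

int main() {
  // Mirrors @blocked_to_dot_op_shortcut_gfx1130 from the new lit test:
  // a 32x32 operand A with sizePerThread = [1, 32], so each thread holds
  // all 32 elements along K.
  std::array<int, 2> shape = {32, 32};
  BlockedLayout src = {{1, 32}, {32, 1}, {2, 2}};
  if (needsSharedMemoryRoundTrip(shape, src, /*opIdx=*/0))
    std::puts("lower via local_alloc + local_load (shared memory)");
  else
    std::puts("bypass: forward the source registers unchanged");
  return 0;
}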

8 files changed: +277, -38 lines changed

include/triton/Analysis/Utility.h

Lines changed: 2 additions & 0 deletions
@@ -203,6 +203,8 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

 bool atomicNeedsSharedMemory(Value result);

+bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
+
 bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

 bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

lib/Analysis/Utility.cpp

Lines changed: 73 additions & 3 deletions
@@ -536,6 +536,75 @@ bool supportMMA(Value value, int version) {
          (elemTy.isInteger(8) && version >= 2);
 }

+bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+  auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
+  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+  if (blockedLayout == nullptr || dotOperandLayout == nullptr)
+    return false;
+  auto parentLayout =
+      dyn_cast<BlockedEncodingAttr>(dotOperandLayout.getParent());
+  if (parentLayout == nullptr)
+    return false;
+  auto opShape = srcTy.getShape();
+  auto rank = opShape.size();
+
+  int kDim = dotOperandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  int nonKDim = dotOperandLayout.getOpIdx() == 0 ? rank - 2 : rank - 1;
+  auto ctaLayout = blockedLayout.getCTALayout();
+
+  // The following logic checks that a source blocked layout matches a
+  // destination dot operand layout. This means that a tensor in the source
+  // layout can be converted into the destination layout without any data
+  // movement between registers or threads.
+  //
+  // It is considered a match if
+  // 1) Each thread in the source layout holds a whole copy of all elements
+  //    along the K dimension of the tensor.
+  // 2) The distribution of data along all other non-K dimensions (Batch/M/N)
+  //    matches between the source and destination parent layouts.
+  //
+  // The first condition comes from a property of the dot operand layout with a
+  // blocked parent: size per thread along the K dimension equals the size of
+  // the tensor along K. The second condition comes from another property: the
+  // dot operand layout inherits its non-K dimensions from its parent layout.
+  //
+  // clang-format off
+  //
+  // For example, the following conversion is a no-op:
+  //   tensor<128x32xf16, #blocked<{sizePerThread = [2, 32], threadsPerWarp = [32, 1]}>>
+  //     ->
+  //   tensor<128x32xf16, #dot_op<{opIdx=0, parent=#blocked<{sizePerThread = [2, 8], threadsPerWarp = [32, 1]}>>>
+  //
+  // clang-format on
+  bool ctaLayoutCompatible =
+      ctaLayout.getCTASplitNum()[kDim] == 1 &&
+      blockedLayout.getCTALayout() == parentLayout.getCTALayout();
+  bool threadHoldsWholeKDim =
+      blockedLayout.getSizePerThread()[kDim] == opShape[kDim];
+  bool nonKDimCompatible =
+      blockedLayout.getOrder() == parentLayout.getOrder() &&
+      blockedLayout.getSizePerThread()[nonKDim] ==
+          parentLayout.getSizePerThread()[nonKDim] &&
+      blockedLayout.getThreadsPerWarp()[nonKDim] ==
+          parentLayout.getThreadsPerWarp()[nonKDim] &&
+      blockedLayout.getWarpsPerCTA()[nonKDim] ==
+          parentLayout.getWarpsPerCTA()[nonKDim];
+  bool matrixDimsCompatible =
+      ctaLayoutCompatible && threadHoldsWholeKDim && nonKDimCompatible;
+  if (rank == 2)
+    return matrixDimsCompatible;
+
+  // Additional check for the batch dimension if it is present.
+  assert(rank == 3);
+  bool bDimCompatible =
+      blockedLayout.getSizePerThread()[0] ==
+          parentLayout.getSizePerThread()[0] &&
+      blockedLayout.getThreadsPerWarp()[0] ==
+          parentLayout.getThreadsPerWarp()[0] &&
+      blockedLayout.getWarpsPerCTA()[0] == parentLayout.getWarpsPerCTA()[0];
+  return matrixDimsCompatible && bDimCompatible;
+}
+
 bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
   auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());

@@ -625,12 +694,13 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
 }

 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
-  // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut` and
-  // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
-  // checks.
+  // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
+  // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
+  // subsumed by the linear-layout checks.
   // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
   // supported yet in Triton's backend.
   return !cvtReordersRegisters(srcTy, dstTy) &&
+         !isBlockedToDotShortcut(srcTy, dstTy) &&
          !isMmaToDotShortcut(srcTy, dstTy) &&
          !isMfmaToDotShortcut(srcTy, dstTy);
 }
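To make the matching conditions above concrete, here is a minimal, dependency-free C++ sketch (not part of the commit) that replays the rank-2 checks on plain arrays for the 128x32 example from the code comment. A single-CTA assumption stands in for the CTASplitNum check, and the warpsPerCTA and order values are assumed, since the comment only specifies sizePerThread and threadsPerWarp.

#include <array>
#include <cstdio>

// Plain-data stand-in for a 2-D blocked layout (illustration only).
struct Layout {
  std::array<int, 2> sizePerThread, threadsPerWarp, warpsPerCTA, order;
};

// Rank-2 version of the predicate above, assuming a single CTA so the
// CTASplitNum condition holds trivially.
bool blockedToDotIsFree(const std::array<int, 2> &shape, const Layout &src,
                        const Layout &dotParent, int opIdx) {
  int kDim = (opIdx == 0) ? 1 : 0;
  int nonKDim = (opIdx == 0) ? 0 : 1;
  bool threadHoldsWholeKDim = src.sizePerThread[kDim] == shape[kDim];
  bool nonKDimCompatible =
      src.order == dotParent.order &&
      src.sizePerThread[nonKDim] == dotParent.sizePerThread[nonKDim] &&
      src.threadsPerWarp[nonKDim] == dotParent.threadsPerWarp[nonKDim] &&
      src.warpsPerCTA[nonKDim] == dotParent.warpsPerCTA[nonKDim];
  return threadHoldsWholeKDim && nonKDimCompatible;
}

int main() {
  // The 128x32 operand-A example from the comment: each thread holds all 32
  // K elements, and the M-dimension distribution matches the dot parent.
  // warpsPerCTA = [4, 1] and order = [1, 0] are assumed values.
  std::array<int, 2> shape = {128, 32};
  Layout src       = {{2, 32}, {32, 1}, {4, 1}, {1, 0}};
  Layout dotParent = {{2, 8},  {32, 1}, {4, 1}, {1, 0}};
  std::printf("conversion is a no-op: %s\n",
              blockedToDotIsFree(shape, src, dotParent, /*opIdx=*/0) ? "yes"
                                                                     : "no");
  return 0;
}

Shrinking the source sizePerThread along K (for example to 16 elements on a 32-wide K) breaks the first condition, which is the case the neg_blocked_to_dot_op_incompatible_elems_gfx940 lit test below exercises.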

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 32 additions & 0 deletions
@@ -232,6 +232,36 @@ struct ConvertLayoutOpConversion
   const TargetInfoBase &targetInfo;
 };

+struct ConvertLayoutOpBlockedToDotOpShortcutConversion
+    : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
+  const TargetInfoBase &targetInfo;
+  explicit ConvertLayoutOpBlockedToDotOpShortcutConversion(
+      LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
+      PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
+  }
+
+  LogicalResult
+  matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = op.getContext();
+
+    const auto &shape = op.getType().getShape();
+    auto srcTy = op.getSrc().getType();
+    auto dstTy = op.getType();
+    auto dstDotEncoding = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+    if (!dstDotEncoding)
+      return failure();
+    if (!isa<BlockedEncodingAttr>(srcTy.getEncoding()) ||
+        !isa<BlockedEncodingAttr>(dstDotEncoding.getParent()))
+      return failure();
+    if (cvtNeedsSharedMemory(srcTy, dstTy))
+      return failure();
+    rewriter.replaceOp(op, adaptor.getSrc());
+    return success();
+  }
+};
+
 struct ConvertLayoutOpUsingLinearLayoutsConversion
     : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
   const TargetInfoBase &targetInfo;

@@ -657,5 +687,7 @@ void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
   // one left.
   mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
       typeConverter, targetInfo, patterns, benefit.getBenefit() + 1);
+  patterns.add<ConvertLayoutOpBlockedToDotOpShortcutConversion>(
+      typeConverter, targetInfo, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
 }

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 2 additions & 0 deletions
@@ -83,6 +83,8 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     OpBuilder builder(cvtOp);
     auto srcType = cast<RankedTensorType>(cvtOp.getSrc().getType());
     auto dstType = cast<RankedTensorType>(cvtOp.getType());
+    if (!cvtNeedsSharedMemory(srcType, dstType))
+      return;
     auto srcBlocked =
         dyn_cast<triton::gpu::BlockedEncodingAttr>(srcType.getEncoding());
     auto dstDotOp =

lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp

Lines changed: 2 additions & 16 deletions
@@ -42,22 +42,8 @@ class TritonGPUReduceDataDuplicationPass
           dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
       if (!dstDotOp)
         return;
-      if (auto srcMmaEncoding =
-              dyn_cast<triton::gpu::NvidiaMmaEncodingAttr>(srcEncoding)) {
-
-        if (srcMmaEncoding.getVersionMajor() != 2 ||
-            (srcMmaEncoding.getWarpsPerCTA()[1] == 1 &&
-             dstDotOp.getParent() == srcMmaEncoding))
-          return;
-      }
-      if (auto srcMfmaEncoding =
-              dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(srcEncoding)) {
-
-        if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 &&
-            srcMfmaEncoding.getIsTransposed() &&
-            dstDotOp.getParent() == srcMfmaEncoding)
-          return;
-      }
+      if (!cvtNeedsSharedMemory(srcType, dstType))
+        return;
       auto srcOrder = triton::gpu::getOrder(srcEncoding);
       auto rank = srcOrder.size();
       SmallVector<unsigned> sharedOrder;
Lines changed: 88 additions & 16 deletions
@@ -1,33 +1,105 @@
-// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx1130 | FileCheck %s
+// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions | FileCheck %s

-// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
-// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
-// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}>
-// CHECK: wmma_to_wmma_dot_op
+// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
+// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
+// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}>
+// CHECK-LABEL: wmma_to_wmma_dot_op
 #mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} {
   tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) {
-    // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[WMMA]]> -> tensor<16x16xf16, #[[BLOCKED]]>
-    // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[SHARED]], #triton_gpu.shared_memory>
-    // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>>
+    // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]>
+    // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory>
+    // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>>
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
     tt.return
   }
 }

 // -----

-// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
-// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
-// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}>
-// CHECK: wmma_to_wmma_dot3d_op
+// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
+// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
+// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}>
+// CHECK-LABEL: wmma_to_wmma_dot3d_op
 #mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
   tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) {
-    // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[WMMA]]> -> tensor<2x16x16xf16, #[[BLOCKED]]>
-    // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[SHARED]], #triton_gpu.shared_memory>
-    // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>>
+    // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]>
+    // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory>
+    // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>>
     %0 = triton_gpu.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
     tt.return
   }
 }
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx1130
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @blocked_to_dot_op_shortcut_gfx1130(%arg0: tensor<32x32xf16, #blocked>) {
+    // CHECK-NOT: triton_gpu.local_alloc
+    // CHECK: triton_gpu.convert_layout
+    // CHECK-NOT: triton_gpu.local_alloc
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx940
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @blocked_to_dot_op_shortcut_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+    // CHECK-NOT: triton_gpu.local_alloc
+    // CHECK: triton_gpu.convert_layout
+    // CHECK-NOT: triton_gpu.local_alloc
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_elems_gfx940
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @neg_blocked_to_dot_op_incompatible_elems_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+    // CHECK-NOT: triton_gpu.convert_layout
+    // CHECK: triton_gpu.local_alloc
+    // CHECK: triton_gpu.local_load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_threads_gfx940
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @neg_blocked_to_dot_op_incompatible_threads_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+    // CHECK-NOT: triton_gpu.convert_layout
+    // CHECK: triton_gpu.local_alloc
+    // CHECK: triton_gpu.local_load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_warp_gfx940
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+    // CHECK-NOT: triton_gpu.convert_layout
+    // CHECK: triton_gpu.local_alloc
+    // CHECK: triton_gpu.local_load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
+    tt.return
+  }
+}
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s
+
+// CHECK-LABEL: blocked_to_dot_op_shortcut_warp32
+#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @blocked_to_dot_op_shortcut_warp32(%arg0: tensor<32x32xf16, #blocked>, %arg1: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK-NOT: load
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot_op_shortcut_warp64
+#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [2, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @blocked_to_dot_op_shortcut_warp64(%arg0: tensor<32x32xf16, #blocked>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK-NOT: load
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp32
+#blocked = #triton_gpu.blocked<{sizePerThread = [2, 32, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @blocked_to_dot3d_op_shortcut_warp32(%arg0: tensor<8x32x32xf16, #blocked>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK-NOT: load
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp64
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32, 1], threadsPerWarp = [1, 2, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @blocked_to_dot3d_op_shortcut_warp64(%arg0: tensor<8x32x32xf16, #blocked>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK-NOT: load
+    tt.return
+  }
+}
