Commit 549c9e1

Merge commit '5f9bb95d657268b3c33f9e7e5dbbde4510d9f704'
2 parents: 414eba6 + 5f9bb95

17 files changed: +369, -47 lines

17 files changed

+369
-47
lines changed

include/triton/Analysis/Utility.h

Lines changed: 2 additions & 0 deletions
@@ -203,6 +203,8 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
 
 bool atomicNeedsSharedMemory(Value result);
 
+bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);
+
 bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 
 bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

lib/Analysis/Utility.cpp

Lines changed: 73 additions & 3 deletions
@@ -543,6 +543,75 @@ bool supportMMA(Value value, int version) {
          (elemTy.isInteger(8) && version >= 2);
 }
 
+bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+  auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
+  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+  if (blockedLayout == nullptr || dotOperandLayout == nullptr)
+    return false;
+  auto parentLayout =
+      dyn_cast<BlockedEncodingAttr>(dotOperandLayout.getParent());
+  if (parentLayout == nullptr)
+    return false;
+  auto opShape = srcTy.getShape();
+  auto rank = opShape.size();
+
+  int kDim = dotOperandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  int nonKDim = dotOperandLayout.getOpIdx() == 0 ? rank - 2 : rank - 1;
+  auto ctaLayout = blockedLayout.getCTALayout();
+
+  // The following logic checks that a source blocked layout matches a
+  // destination dot operand layout. This means that a given tensor in the
+  // source layout could be converted into the destination layout without any
+  // data movement between registers or threads.
+  //
+  // It is considered a match if
+  // 1) Each thread in the source layout holds a whole copy of all elements
+  //    along the K dimension of the tensor
+  // 2) Distribution of data along all other non-K dimensions (Batch/M/N)
+  //    matches between the source and destination parent layouts.
+  //
+  // The first condition comes from a property of the dot operand layout with a
+  // Blocked parent: size per thread along the K dimension equals the size of
+  // the tensor along K. The second condition comes from another property: the
+  // dot operand layout inherits non-K dimensions from its parent layout.
+  //
+  // clang-format off
+  //
+  // For example, the following conversion is a no-op:
+  //   tensor<128x32xf16, #blocked<{sizePerThread = [2, 32], threadsPerWarp = [32, 1]}>>
+  //     ->
+  //   tensor<128x32xf16, #dot_op<{opIdx=0, parent=#blocked<{sizePerThread = [2, 8], threadsPerWarp = [32, 1]}>>>
+  //
+  // clang-format on
+  bool ctaLayoutCompatible =
+      ctaLayout.getCTASplitNum()[kDim] == 1 &&
+      blockedLayout.getCTALayout() == parentLayout.getCTALayout();
+  bool threadHoldsWholeKDim =
+      blockedLayout.getSizePerThread()[kDim] == opShape[kDim];
+  bool nonKDimCompatible =
+      blockedLayout.getOrder() == parentLayout.getOrder() &&
+      blockedLayout.getSizePerThread()[nonKDim] ==
+          parentLayout.getSizePerThread()[nonKDim] &&
+      blockedLayout.getThreadsPerWarp()[nonKDim] ==
+          parentLayout.getThreadsPerWarp()[nonKDim] &&
+      blockedLayout.getWarpsPerCTA()[nonKDim] ==
+          parentLayout.getWarpsPerCTA()[nonKDim];
+  bool matrixDimsCompatible =
+      ctaLayoutCompatible && threadHoldsWholeKDim && nonKDimCompatible;
+  if (rank == 2)
+    return matrixDimsCompatible;
+
+  // Additional check for the batch dimension if it is present.
+  assert(rank == 3);
+  bool bDimCompatible =
+      blockedLayout.getSizePerThread()[0] ==
+          parentLayout.getSizePerThread()[0] &&
+      blockedLayout.getThreadsPerWarp()[0] ==
+          parentLayout.getThreadsPerWarp()[0] &&
+      blockedLayout.getWarpsPerCTA()[0] == parentLayout.getWarpsPerCTA()[0];
+  return matrixDimsCompatible && bDimCompatible;
+}
+
 bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
   auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
@@ -632,13 +701,14 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
 }
 
 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
-  // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut` and
-  // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
-  // checks.
+  // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
+  // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
+  // subsumed by the linear-layout checks.
   // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
   // supported yet in Triton's backend.
   return !cvtReordersRegisters(srcTy, dstTy) &&
          !triton::gpu::intel::isDpasToDotShortcut(srcTy, dstTy) &&
+         !isBlockedToDotShortcut(srcTy, dstTy) &&
         !isMmaToDotShortcut(srcTy, dstTy) &&
         !isMfmaToDotShortcut(srcTy, dstTy);
 }
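
To make the new check concrete, here is a minimal IR-level sketch of a conversion that isBlockedToDotShortcut would treat as a no-op, built from the example in the function's own comment; the warpsPerCTA, order values, function name, and module attributes are assumptions added only to make the snippet self-contained and are not part of this commit:

// Hypothetical example, not from this commit: each thread in #blocked_src
// already owns the whole K dimension (sizePerThread[1] == 32 == K), and the
// non-K dimension distribution matches the dot operand's blocked parent, so
// the convert_layout needs no data movement between registers or threads.
#blocked_src = #triton_gpu.blocked<{sizePerThread = [2, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked_parent = #triton_gpu.blocked<{sizePerThread = [2, 8], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func @blocked_to_dot_op_noop(%arg0: tensor<128x32xf16, #blocked_src>) {
    %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf16, #blocked_src> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked_parent}>>
    tt.return
  }
}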

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 32 additions & 0 deletions
@@ -232,6 +232,36 @@ struct ConvertLayoutOpConversion
   const TargetInfoBase &targetInfo;
 };
 
+struct ConvertLayoutOpBlockedToDotOpShortcutConversion
+    : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
+  const TargetInfoBase &targetInfo;
+  explicit ConvertLayoutOpBlockedToDotOpShortcutConversion(
+      LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
+      PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
+  }
+
+  LogicalResult
+  matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = op.getContext();
+
+    const auto &shape = op.getType().getShape();
+    auto srcTy = op.getSrc().getType();
+    auto dstTy = op.getType();
+    auto dstDotEncoding = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+    if (!dstDotEncoding)
+      return failure();
+    if (!isa<BlockedEncodingAttr>(srcTy.getEncoding()) ||
+        !isa<BlockedEncodingAttr>(dstDotEncoding.getParent()))
+      return failure();
+    if (cvtNeedsSharedMemory(srcTy, dstTy))
+      return failure();
+    rewriter.replaceOp(op, adaptor.getSrc());
+    return success();
+  }
+};
+
 struct ConvertLayoutOpUsingLinearLayoutsConversion
     : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
   const TargetInfoBase &targetInfo;
@@ -657,5 +687,7 @@ void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
   // one left.
   mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
       typeConverter, targetInfo, patterns, benefit.getBenefit() + 1);
+  patterns.add<ConvertLayoutOpBlockedToDotOpShortcutConversion>(
+      typeConverter, targetInfo, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
 }

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 2 additions & 0 deletions
@@ -83,6 +83,8 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     OpBuilder builder(cvtOp);
     auto srcType = cast<RankedTensorType>(cvtOp.getSrc().getType());
     auto dstType = cast<RankedTensorType>(cvtOp.getType());
+    if (!cvtNeedsSharedMemory(srcType, dstType))
+      return;
     auto srcBlocked =
         dyn_cast<triton::gpu::BlockedEncodingAttr>(srcType.getEncoding());
     auto dstDotOp =

lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp

Lines changed: 1 addition & 0 deletions
@@ -413,6 +413,7 @@ class RewriteTensorPointerPass
     auto newForOp = builder.create<scf::ForOp>(op.getLoc(), op.getLowerBound(),
                                                op.getUpperBound(), op.getStep(),
                                                newIterOperands);
+    newForOp->setAttrs(op->getAttrs());
 
     // Create value mapping. Note that for tensor pointers, we use identity
     // mapping. It may refer to a value in the old loop, but we will rewrite it
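
A brief aside on this one-line change: copying op->getAttrs() onto the recreated scf.for keeps discardable loop attributes alive across the rewrite. A minimal sketch of the kind of loop this affects; the tt.num_stages attribute, function name, and types are illustrative assumptions, not taken from this commit:

// Hypothetical loop: the trailing attribute dictionary on scf.for is what
// newForOp->setAttrs(op->getAttrs()) now copies onto the rebuilt loop.
tt.func @loop_keeps_attrs(%lb: index, %ub: index, %step: index, %init: tensor<128xf32>) {
  %res = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (tensor<128xf32>) {
    scf.yield %acc : tensor<128xf32>
  } {tt.num_stages = 3 : i32}
  tt.return
}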

lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp

Lines changed: 2 additions & 16 deletions
@@ -42,22 +42,8 @@ class TritonGPUReduceDataDuplicationPass
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
     if (!dstDotOp)
       return;
-    if (auto srcMmaEncoding =
-            dyn_cast<triton::gpu::NvidiaMmaEncodingAttr>(srcEncoding)) {
-
-      if (srcMmaEncoding.getVersionMajor() != 2 ||
-          (srcMmaEncoding.getWarpsPerCTA()[1] == 1 &&
-           dstDotOp.getParent() == srcMmaEncoding))
-        return;
-    }
-    if (auto srcMfmaEncoding =
-            dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(srcEncoding)) {
-
-      if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 &&
-          srcMfmaEncoding.getIsTransposed() &&
-          dstDotOp.getParent() == srcMfmaEncoding)
-        return;
-    }
+    if (!cvtNeedsSharedMemory(srcType, dstType))
+      return;
     auto srcOrder = triton::gpu::getOrder(srcEncoding);
     auto rank = srcOrder.size();
     SmallVector<unsigned> sharedOrder;

python/test/unit/language/test_core.py

Lines changed: 2 additions & 1 deletion
@@ -3375,7 +3375,8 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
     if is_hip():
         # hip does not support tf32 precision, so use ieee for all tests
         input_precision = "ieee"
-        if "gfx11" in triton.runtime.driver.active.get_current_target().arch:
+        arch = triton.runtime.driver.active.get_current_target().arch
+        if "gfx11" in arch or "gfx12" in arch:
             if in_dtype_str == "float32":
                 pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d")
             if out_dtype_str == "float16":

Lines changed: 88 additions & 16 deletions
@@ -1,33 +1,105 @@
-// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx1130 | FileCheck %s
+// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions | FileCheck %s
 
-// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
-// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
-// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}>
-// CHECK: wmma_to_wmma_dot_op
+// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
+// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
+// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}>
+// CHECK-LABEL: wmma_to_wmma_dot_op
 #mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} {
   tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) {
-    // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[WMMA]]> -> tensor<16x16xf16, #[[BLOCKED]]>
-    // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[SHARED]], #triton_gpu.shared_memory>
-    // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>>
+    // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]>
+    // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory>
+    // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>>
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
     tt.return
   }
 }
 
 // -----
 
-// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
-// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
-// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}>
-// CHECK: wmma_to_wmma_dot3d_op
+// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
+// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
+// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}>
+// CHECK-LABEL: wmma_to_wmma_dot3d_op
 #mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
   tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) {
-    // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[WMMA]]> -> tensor<2x16x16xf16, #[[BLOCKED]]>
-    // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[SHARED]], #triton_gpu.shared_memory>
-    // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>>
+    // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]>
+    // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory>
+    // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>>
     %0 = triton_gpu.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
     tt.return
   }
 }
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx1130
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @blocked_to_dot_op_shortcut_gfx1130(%arg0: tensor<32x32xf16, #blocked>) {
+    // CHECK-NOT: triton_gpu.local_alloc
+    // CHECK: triton_gpu.convert_layout
+    // CHECK-NOT: triton_gpu.local_alloc
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx940
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @blocked_to_dot_op_shortcut_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+    // CHECK-NOT: triton_gpu.local_alloc
+    // CHECK: triton_gpu.convert_layout
+    // CHECK-NOT: triton_gpu.local_alloc
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_elems_gfx940
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @neg_blocked_to_dot_op_incompatible_elems_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+    // CHECK-NOT: triton_gpu.convert_layout
+    // CHECK: triton_gpu.local_alloc
+    // CHECK: triton_gpu.local_load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_threads_gfx940
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @neg_blocked_to_dot_op_incompatible_threads_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+    // CHECK-NOT: triton_gpu.convert_layout
+    // CHECK: triton_gpu.local_alloc
+    // CHECK: triton_gpu.local_load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_warp_gfx940
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+    // CHECK-NOT: triton_gpu.convert_layout
+    // CHECK: triton_gpu.local_alloc
+    // CHECK: triton_gpu.local_load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
+    tt.return
+  }
+}

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s
+
+// CHECK-LABEL: blocked_to_dot_op_shortcut_warp32
+#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @blocked_to_dot_op_shortcut_warp32(%arg0: tensor<32x32xf16, #blocked>, %arg1: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK-NOT: load
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot_op_shortcut_warp64
+#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [2, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @blocked_to_dot_op_shortcut_warp64(%arg0: tensor<32x32xf16, #blocked>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK-NOT: load
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp32
+#blocked = #triton_gpu.blocked<{sizePerThread = [2, 32, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func @blocked_to_dot3d_op_shortcut_warp32(%arg0: tensor<8x32x32xf16, #blocked>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK-NOT: load
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp64
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32, 1], threadsPerWarp = [1, 2, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func @blocked_to_dot3d_op_shortcut_warp64(%arg0: tensor<8x32x32xf16, #blocked>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK-NOT: load
+    tt.return
+  }
+}
