
Commit 4f6f088

Merge commit 'cc89dac07b7acf3af9962d83250a8bc015fc5a91'
2 parents 2350d5a + cc89dac commit 4f6f088

9 files changed, +198 -26 lines changed

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
@@ -374,9 +374,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere)
     // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be
     // completed before we can remove the layoutIsOK check:
-    // 1. Support for AMD's WMMA
+    // 1. Support for AMD's WMMA dot operand
     std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
-      if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
+      if (isa<MmaEncodingTrait>(layout)) {
        return !useLegacyMMAConversion;
      }
      if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
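
Note on the change above: switching from a list of concrete encodings to the MmaEncodingTrait interface means every MMA-family layout (NVIDIA MMA, AMD MFMA, and now AMD WMMA, which the new lit tests below exercise) is accepted by the linear-layout conversion path. A minimal standalone sketch of the predicate, assuming the TritonGPU dialect headers; the dot-operand recursion and the blocked/slice fallback here are illustrative rather than copied from the file:

// Hedged sketch, not the exact source: accept any encoding that implements
// MmaEncodingTrait instead of naming NvidiaMmaEncodingAttr / AMDMfmaEncodingAttr.
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

static bool layoutIsOKSketch(mlir::Attribute layout, bool useLegacyMMAConversion) {
  using namespace mlir::triton::gpu;
  if (mlir::isa<MmaEncodingTrait>(layout)) // NVIDIA MMA, AMD MFMA, AMD WMMA, ...
    return !useLegacyMMAConversion;
  if (auto dotOperand = mlir::dyn_cast<DotOperandEncodingAttr>(layout))
    return layoutIsOKSketch(dotOperand.getParent(), useLegacyMMAConversion);
  // Illustrative fallback: blocked and slice layouts were already handled.
  return mlir::isa<BlockedEncodingAttr, SliceEncodingAttr>(layout);
}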

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 5 additions & 5 deletions
@@ -23,9 +23,6 @@ void lowerDistributedToShared(
   auto srcTy = cast<RankedTensorType>(src.getType());
   auto dstTy = cast<MemDescType>(dst.getType());
   auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
-  assert(srcTy.getShape().size() <= 2 ||
-         (srcTy.getShape().size() == 3 && outOrd[2] == 0) &&
-             "Unexpected rank of ConvertLayout(blocked->shared)");
   auto elemTy = typeConverter->convertType(srcTy.getElementType());
 
   auto smemBase = smemObj.getBase();
@@ -163,7 +160,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
         srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
     // To be removed in https://github.com/triton-lang/triton/pull/5154
     bool legacyLoweringIsBuggy =
-        (kWidth >= 8 || (kWidth == 4 && bitwidth == 32)) && mma.isAmpere();
+        (kWidth >= 8 || (kWidth == 4 && bitwidth == 32) ||
+         dstTy.getRank() == 3) &&
+        mma.isAmpere();
     return (mma.isHopper() && !canUseLdmatrix) ||
            (mma.isAmpere() && legacyLoweringIsBuggy);
   }
@@ -220,7 +219,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto dstTy = op.getResult().getType();
     auto dstShape = dstTy.getShape();
     auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
-    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(srcTy, dstTy)) &&
+    assert((!isa<DotOperandEncodingAttr>(dstTy.getEncoding()) ||
+            isSupportedDotOpLayout(srcTy, dstTy)) &&
            "Unexpected rank of ConvertLayout(shared->distributed)");
 
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
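
For context, a hedged restatement of the routing logic touched above (the free-standing function and its parameters are illustrative; only the predicate mirrors the diff): on Ampere, rank-3 destination tensors now join the cases where the legacy shared-to-dot-operand lowering is considered buggy, so they fall through to the newer lowering instead.

// Sketch only: when LocalLoadOp should avoid the legacy shared->dot-operand path.
static bool avoidLegacyLowering(bool isHopper, bool isAmpere, bool canUseLdmatrix,
                                int kWidth, int bitwidth, int64_t dstRank) {
  // Known-buggy legacy cases on Ampere: kWidth >= 8, kWidth == 4 with 32-bit
  // elements, and (new in this commit) rank-3 destination tensors.
  bool legacyLoweringIsBuggy =
      (kWidth >= 8 || (kWidth == 4 && bitwidth == 32) || dstRank == 3) && isAmpere;
  return (isHopper && !canUseLdmatrix) || (isAmpere && legacyLoweringIsBuggy);
}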

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 6 additions & 2 deletions
@@ -113,8 +113,12 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
   mlir::getForwardSlice(dotOp.getResult(), &slices);
-  if (llvm::find_if(slices, [](Operation *op) { return isa<DotOp>(op); }) !=
-      slices.end())
+  // Contains a chained dot. We prefer to assign warps to one axis
+  // to facilitate use cases like flash attention, allowing reductions within
+  // the same warp.
+  if (llvm::find_if(slices, [](Operation *op) {
+        return op->hasTrait<OpTrait::DotLike>();
+      }) != slices.end())
     return {(unsigned)numWarps, 1};
 
   // For MMAv3, the smallest indivisible unit of warp shape is (4, 1).
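
The new comment states the heuristic: if another dot-like op consumes this dot's result (the flash-attention pattern), keep all warps on one axis so the intermediate reductions stay within a warp. The accelerate-matmul.mlir test added below shows the effect: with 8 warps, the first warp_group_dot is assigned warpsPerCTA = [8, 1] while the final one gets [4, 2]. A short sketch of the trait-based detection, with anything not in the diff being illustrative:

// Sketch: a dot counts as "chained" if its forward slice contains any op with
// the DotLike trait (e.g. tt.dot or a warp-group dot), not just another tt.dot.
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
bool chained = llvm::any_of(slices, [](Operation *op) {
  return op->hasTrait<mlir::OpTrait::DotLike>();
});
if (chained)
  return {(unsigned)numWarps, 1}; // e.g. 8 warps -> warpsPerCTA = [8, 1]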

python/setup.py

Lines changed: 1 addition & 1 deletion
@@ -493,7 +493,7 @@ def build_extension(self, ext):
             "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DLLVM_ENABLE_WERROR=ON",
             "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_TUTORIALS=OFF",
             "-DTRITON_BUILD_PYTHON_MODULE=ON", "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
-            "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", "-DPython3_INCLUDE_DIR=" + python_include_dir,
+            "-DPython3_INCLUDE_DIR=" + python_include_dir,
             "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
             "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
         ]

python/test/unit/language/test_core.py

Lines changed: 91 additions & 0 deletions
@@ -5433,6 +5433,97 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
     assert torch.equal(z, x)
 
 
+layouts_3d = [
+    BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), op_idx=0,
+                     k_width=1),
+]
+
+shared_layout_3d = [
+    SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+]
+
+
+@pytest.mark.parametrize("M, N, K", [[8, 16, 32]])
+@pytest.mark.parametrize("shared_layout", shared_layout_3d)
+@pytest.mark.parametrize("dist_layout", layouts_3d)
+def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path):
+    layouts = f"""
+    #dist = {dist_layout}
+    #shared = {shared_layout}
+    """
+    ir = layouts + f"""
+    module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
+      tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
+        %cst = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist>
+        %cst_0 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist>
+        %cst_1 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist>
+        %cst_2 = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist>
+        %0 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>>
+        %1 = tt.expand_dims %0 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
+        %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist>
+        %3 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x1x{K}x!tt.ptr<i32>, #dist>
+        %4 = tt.addptr %3, %2 : tensor<1x1x{K}x!tt.ptr<i32>, #dist>, tensor<1x1x{K}xi32, #dist>
+        %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
+        %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
+        %7 = tt.expand_dims %6 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist>
+        %8 = arith.muli %7, %cst_2 : tensor<1x{N}x1xi32, #dist>
+        %9 = tt.broadcast %4 : tensor<1x1x{K}x!tt.ptr<i32>, #dist> -> tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>
+        %10 = tt.broadcast %8 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist>
+        %11 = tt.addptr %9, %10 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<1x{N}x{K}xi32, #dist>
+        %12 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
+        %13 = tt.expand_dims %12 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
+        %14 = tt.expand_dims %13 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist>
+        %15 = arith.muli %14, %cst_1 : tensor<{M}x1x1xi32, #dist>
+        %16 = tt.broadcast %11 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
+        %17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist>
+        %18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<{M}x{N}x{K}xi32, #dist>
+        %19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
+        %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory>
+        %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> -> tensor<{M}x{N}x{K}xi32, #dist>
+        %22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>>
+        %23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
+        %24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist>
+        %25 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1x1x{K}x!tt.ptr<i32>, #dist>
+        %26 = tt.addptr %25, %24 : tensor<1x1x{K}x!tt.ptr<i32>, #dist>, tensor<1x1x{K}xi32, #dist>
+        %27 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
+        %28 = tt.expand_dims %27 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
+        %29 = tt.expand_dims %28 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist>
+        %30 = arith.muli %29, %cst : tensor<1x{N}x1xi32, #dist>
+        %31 = tt.broadcast %26 : tensor<1x1x{K}x!tt.ptr<i32>, #dist> -> tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>
+        %32 = tt.broadcast %30 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist>
+        %33 = tt.addptr %31, %32 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<1x{N}x{K}xi32, #dist>
+        %34 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
+        %35 = tt.expand_dims %34 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
+        %36 = tt.expand_dims %35 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist>
+        %37 = arith.muli %36, %cst_0 : tensor<{M}x1x1xi32, #dist>
+        %38 = tt.broadcast %33 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
+        %39 = tt.broadcast %37 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist>
+        %40 = tt.addptr %38, %39 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<{M}x{N}x{K}xi32, #dist>
+        tt.store %40, %21 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
+        tt.return
+      }}
+    }}
+    """
+
+    if is_xpu() and isinstance(dist_layout, DotOperandLayout) and isinstance(dist_layout.parent, MmaLayout):
+        pytest.xfail("DotOperandLayout with MmaLayout is not supported in XPU")
+
+    x = torch.arange(0, M * N * K, device=device, dtype=torch.int32).reshape(M, N, K)
+    z = torch.empty_like(x, device=device)
+
+    temp_file = tmp_path / "test_local_load_store.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
+
+    kernel[(1, 1, 1)](x, z)
+    assert torch.equal(z, x)
+
+
 mma_pairs = [
     [
         MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),

python/triton/compiler/code_generator.py

Lines changed: 0 additions & 12 deletions
@@ -1,7 +1,6 @@
 import ast
 import inspect
 import re
-import sys
 import warnings
 import os
 import textwrap
@@ -1176,17 +1175,6 @@ def visit_BoolOp(self, node: ast.BoolOp):
 
     _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
 
-    if sys.version_info < (3, 8):
-
-        def visit_NameConstant(self, node):
-            return constexpr(node.value)
-
-        def visit_Num(self, node):
-            return constexpr(node.n)
-
-        def visit_Str(self, node):
-            return constexpr(ast.literal_eval(node))
-
     def visit_Attribute(self, node):
         lhs = self.visit(node.value)
         if _is_triton_tensor(lhs) and node.attr == "T":

test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir

Lines changed: 65 additions & 0 deletions
@@ -1,5 +1,6 @@
 // RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
 #mma1 = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
 #mma2 = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}>
@@ -97,6 +98,70 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
     tt.return
   }
+
+  // CHECK-LABEL: blocked_to_wmma1
+  tt.func @blocked_to_wmma1(%arg0: tensor<128x16xi32, #blocked>) {
+    // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma1>
+    tt.return
+  }
+
+  // CHECK-LABEL: slice_blocked_to_wmma1
+  tt.func @slice_blocked_to_wmma1(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) {
+    // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: wmma1_to_blocked
+  tt.func @wmma1_to_blocked(%arg0: tensor<128x16xi32, #mma1>) {
+    // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma1> -> tensor<128x16xi32, #blocked>
+    tt.return
+  }
+
+  // CHECK-LABEL: slice_wmma1_to_blocked
+  tt.func @slice_wmma1_to_blocked(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>>) {
+    // CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)>
+    // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: blocked_to_wmma2
+  tt.func @blocked_to_wmma2(%arg0: tensor<128x16xi32, #blocked>) {
+    // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma2>
+    tt.return
+  }
+
+  // CHECK-LABEL: slice_blocked_to_wmma2
+  tt.func @slice_blocked_to_wmma2(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) {
+    // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: wmma2_to_blocked
+  tt.func @wmma2_to_blocked(%arg0: tensor<128x16xi32, #mma2>) {
+    // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma2> -> tensor<128x16xi32, #blocked>
+    tt.return
+  }
+
+  // CHECK-LABEL: slice_wmma2_to_blocked
+  tt.func @slice_wmma2_to_blocked(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>>) {
+    // CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)>
+    // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    tt.return
+  }
 }
 
 // -----

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 27 additions & 0 deletions
@@ -73,6 +73,33 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-
 
 // -----
 
+// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}>
+// CHECK: #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}>
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: chained_dot
+  tt.func public @chained_dot_wgmma(
+    %arg0: tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
+    %arg1: tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
+    %arg2: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> {
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked>
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1>
+    // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x64xf32, #mma>
+    %d = tt.dot %arg0, %arg1, %cst_0 :
+      tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked>
+    %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked>
+    %c = ttg.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
+    // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x128xf32, #mma1>
+    %r = tt.dot %c, %arg2, %cst_1 :
+      tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1>
+    tt.return %r : tensor<64x128xf32, #blocked1>
+  }
+}
+
+// -----
+
 // CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
 #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>