
Commit cc89dac

Add tests for 3D local_load local_alloc and relax asserts (#5285)
Also switch the 3D dot_operand cases to the linear layout path. This may be suboptimal in some cases, but it solves the functionality problems, which matters more. There is ongoing work from Mario that should bring the code quality back up soon.
1 parent 1cb0d99 commit cc89dac
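
For context, the new test drives hand-written TTGIR through the same compile-and-launch path that test_core.py already uses for layout-conversion tests. Below is a minimal sketch of that pattern; the helper name run_ttgir_roundtrip and its signature are illustrative only and not part of this commit (the actual test body is in the diff further down).

import pathlib

import torch
import triton


def run_ttgir_roundtrip(ir_text: str, shape, tmp_path: pathlib.Path, device="cuda"):
    # Hypothetical helper mirroring the new test: write the hand-written TTGIR
    # to disk and compile it directly, bypassing the Python frontend.
    temp_file = tmp_path / "kernel.ttgir"
    temp_file.write_text(ir_text)
    kernel = triton.compile(str(temp_file))

    # A dense arange input makes any layout or indexing mix-up visible.
    M, N, K = shape
    x = torch.arange(0, M * N * K, device=device, dtype=torch.int32).reshape(M, N, K)
    z = torch.empty_like(x)

    # Single program instance; the kernel round-trips x through shared memory into z.
    kernel[(1, 1, 1)](x, z)
    assert torch.equal(z, x)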

2 files changed: +93 −5 lines


lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 5 additions & 5 deletions
@@ -23,9 +23,6 @@ void lowerDistributedToShared(
   auto srcTy = cast<RankedTensorType>(src.getType());
   auto dstTy = cast<MemDescType>(dst.getType());
   auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
-  assert(srcTy.getShape().size() <= 2 ||
-         (srcTy.getShape().size() == 3 && outOrd[2] == 0) &&
-             "Unexpected rank of ConvertLayout(blocked->shared)");
   auto elemTy = typeConverter->convertType(srcTy.getElementType());

   auto smemBase = smemObj.getBase();
@@ -163,7 +160,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
         srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
     // To be removed in https://github.com/triton-lang/triton/pull/5154
     bool legacyLoweringIsBuggy =
-        (kWidth >= 8 || (kWidth == 4 && bitwidth == 32)) && mma.isAmpere();
+        (kWidth >= 8 || (kWidth == 4 && bitwidth == 32) ||
+         dstTy.getRank() == 3) &&
+        mma.isAmpere();
     return (mma.isHopper() && !canUseLdmatrix) ||
            (mma.isAmpere() && legacyLoweringIsBuggy);
   }
@@ -220,7 +219,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto dstTy = op.getResult().getType();
     auto dstShape = dstTy.getShape();
     auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
-    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(srcTy, dstTy)) &&
+    assert((!isa<DotOperandEncodingAttr>(dstTy.getEncoding()) ||
+            isSupportedDotOpLayout(srcTy, dstTy)) &&
            "Unexpected rank of ConvertLayout(shared->distributed)");

     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(

python/test/unit/language/test_core.py

Lines changed: 88 additions & 0 deletions
@@ -5383,6 +5383,94 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
     assert torch.equal(z, x)


+layouts_3d = [
+    BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), op_idx=0,
+                     k_width=1),
+]
+
+shared_layout_3d = [
+    SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+]
+
+
+@pytest.mark.parametrize("M, N, K", [[8, 16, 32]])
+@pytest.mark.parametrize("shared_layout", shared_layout_3d)
+@pytest.mark.parametrize("dist_layout", layouts_3d)
+def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path):
+    layouts = f"""
+    #dist = {dist_layout}
+    #shared = {shared_layout}
+    """
+    ir = layouts + f"""
+    module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
+    tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
+    %cst = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist>
+    %cst_0 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist>
+    %cst_1 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist>
+    %cst_2 = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist>
+    %0 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>>
+    %1 = tt.expand_dims %0 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
+    %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist>
+    %3 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x1x{K}x!tt.ptr<i32>, #dist>
+    %4 = tt.addptr %3, %2 : tensor<1x1x{K}x!tt.ptr<i32>, #dist>, tensor<1x1x{K}xi32, #dist>
+    %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
+    %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
+    %7 = tt.expand_dims %6 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist>
+    %8 = arith.muli %7, %cst_2 : tensor<1x{N}x1xi32, #dist>
+    %9 = tt.broadcast %4 : tensor<1x1x{K}x!tt.ptr<i32>, #dist> -> tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>
+    %10 = tt.broadcast %8 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist>
+    %11 = tt.addptr %9, %10 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<1x{N}x{K}xi32, #dist>
+    %12 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
+    %13 = tt.expand_dims %12 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
+    %14 = tt.expand_dims %13 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist>
+    %15 = arith.muli %14, %cst_1 : tensor<{M}x1x1xi32, #dist>
+    %16 = tt.broadcast %11 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
+    %17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist>
+    %18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<{M}x{N}x{K}xi32, #dist>
+    %19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
+    %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory>
+    %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> -> tensor<{M}x{N}x{K}xi32, #dist>
+    %22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>>
+    %23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
+    %24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist>
+    %25 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1x1x{K}x!tt.ptr<i32>, #dist>
+    %26 = tt.addptr %25, %24 : tensor<1x1x{K}x!tt.ptr<i32>, #dist>, tensor<1x1x{K}xi32, #dist>
+    %27 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
+    %28 = tt.expand_dims %27 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
+    %29 = tt.expand_dims %28 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist>
+    %30 = arith.muli %29, %cst : tensor<1x{N}x1xi32, #dist>
+    %31 = tt.broadcast %26 : tensor<1x1x{K}x!tt.ptr<i32>, #dist> -> tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>
+    %32 = tt.broadcast %30 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist>
+    %33 = tt.addptr %31, %32 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<1x{N}x{K}xi32, #dist>
+    %34 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>>
+    %35 = tt.expand_dims %34 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>>
+    %36 = tt.expand_dims %35 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist>
+    %37 = arith.muli %36, %cst_0 : tensor<{M}x1x1xi32, #dist>
+    %38 = tt.broadcast %33 : tensor<1x{N}x{K}x!tt.ptr<i32>, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
+    %39 = tt.broadcast %37 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist>
+    %40 = tt.addptr %38, %39 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<{M}x{N}x{K}xi32, #dist>
+    tt.store %40, %21 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
+    tt.return
+    }}
+    }}
+    """
+
+    x = torch.arange(0, M * N * K, device=device, dtype=torch.int32).reshape(M, N, K)
+    z = torch.empty_like(x, device=device)
+
+    temp_file = tmp_path / "test_local_load_store.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
+
+    kernel[(1, 1, 1)](x, z)
+    assert torch.equal(z, x)
+
+
 mma_pairs = [
     [
         MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
