
Commit 34676a2

[AMD] Add gfx1250 support for ds_read_tr (#8461)
In gfx1250 the instruction has a different size depending on the datatype (b128 for b16 types and b64 for b8 types). The ROCDL transpose instructions aren't upstream yet, so this change uses LLVM intrinsics directly; once the ROCDL instructions are available we will switch to those. As part of the change I've also refactored the ds_read_tr parameter setup.
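
For context, the tests added below check that 16-bit element types lower to llvm.amdgcn.ds.load.tr16.b128 while 8-bit types lower to llvm.amdgcn.ds.load.tr8.b64. The sketch below only illustrates that size selection; the helper name and signature are hypothetical and are not the code added by this commit.

// Illustrative sketch only (hypothetical helper, not this commit's code):
// pick the gfx1250 ds_read_tr intrinsic by element bit width.
#include <cassert>
#include <string>

std::string chooseDsReadTrIntrinsic(unsigned elemBitWidth) {
  if (elemBitWidth == 16)
    return "llvm.amdgcn.ds.load.tr16.b128"; // tests below check vector<8xf16>/vector<8xbf16> results
  if (elemBitWidth == 8)
    return "llvm.amdgcn.ds.load.tr8.b64"; // tests below check vector<2xi32> results
  assert(false && "ds_read_tr transpose loads handle only 8- and 16-bit element types");
  return "";
}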
1 parent 3c8c0cf commit 34676a2

7 files changed, +347 -65 lines changed


lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 12 additions & 6 deletions
@@ -538,14 +538,20 @@ chooseLLDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
   auto kRegister = S("register");
   auto kLane = S("lane");
 
-  // Make sure that we have enough bases to rotate, otherwise we can't return
-  // a valid ds_read_tr layout
-  if (ldsTransLayout.getInDimSizeLog2(kRegister) < numRegBases ||
-      ldsTransLayout.getInDimSizeLog2(kLane) < numLaneBases) {
+  // Make sure that we have enough register bases to rotate, otherwise we
+  // can't return a valid ds_read_tr layout
+  if (ldsTransLayout.getInDimSizeLog2(kRegister) < numRegBases) {
     return std::nullopt;
   }
-
+  // We should always have enough lanes
+  assert(ldsTransLayout.getInDimSizeLog2(kLane) >= numLaneBases);
   rotatePrefixes(bases[kRegister], numRegBases, bases[kLane], numLaneBases);
+  // Scale types double the elements for a total of 16 vgpr (still only 16
+  // elements contiguous). Need to adjust the lane basis to reflect that
+  if (elemBitWidth == 8 && numLanesInShuffleGroup == 8) {
+    assert(ldsTransLayout.getInDimSizeLog2(kLane) >= (numLaneBases + 1));
+    std::swap(bases[kLane][numLaneBases - 1], bases[kLane][numLaneBases]);
+  }
 
   return LinearLayout(bases, ldsTransLayout.getOutDims(), false);
 }
@@ -1127,7 +1133,7 @@ LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
   // addressable blocks If the zero is in any other row/col (i.e. within a given
   // warp-addressable tmem space) it means it is not defined
 
-  // We model packed layouts as having the rows/cols dimensions of bitwidth=16
+  // We model packed layouts as having the rows/cols dimensions of bitWidth=16
   // This means that a layout with unpacked=True is the same as one with
   // unpacked=False
   assert(shape.size() == 2);
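
The scale-type branch above only swaps two lane bases; the standalone sketch below reproduces that adjustment on a toy bases container. The container shape and values are illustrative and this is not Triton's LinearLayout API.

// Illustrative reproduction of the lane-basis adjustment: for 8-bit scale
// types the element count doubles, so the last rotated lane basis is swapped
// with the one after it.
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  using BasisList = std::vector<std::vector<int32_t>>;
  std::map<std::string, BasisList> bases;
  bases["lane"] = {{1, 0}, {2, 0}, {4, 0}, {0, 16}, {8, 0}}; // toy values
  const unsigned numLaneBases = 4;
  // Mirrors the diff: when elemBitWidth == 8 and numLanesInShuffleGroup == 8,
  // swap bases[kLane][numLaneBases - 1] with bases[kLane][numLaneBases].
  std::swap(bases["lane"][numLaneBases - 1], bases["lane"][numLaneBases]);
  // bases["lane"] is now {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 16}}.
  return 0;
}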
Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm=arch=gfx1250 --convert-builtin-func-to-llvm | FileCheck %s
+
+#mma_b16 = #ttg.amd_wmma<{version = 3, warpsPerCTA = [2, 2], instrShape = [16, 16, 32]}> // b16
+#mma_b8 = #ttg.amd_wmma<{version = 3, warpsPerCTA = [2, 2], instrShape = [16, 16, 64]}> // b8
+#mma_b8_2x = #ttg.amd_wmma<{version = 3, warpsPerCTA = [2, 2], instrShape = [16, 16, 128]}> // b8
+#linear_ds_tr = #ttg.linear<{register = [[0, 64], [16, 0], [0, 1], [32, 0], [0, 2], [0, 4], [64, 0], [0, 8], [0, 32]],
+                             lane = [[1, 0], [2, 0], [4, 0], [0, 16], [8, 0]], warp = [[0, 0], [0, 0]], block = []}>
+
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#padding = #ttg.padded_shared<[512:+16] {order = [0, 1], shape = [128, 64]}>
+#padding_vec1 = #ttg.padded_shared<[1:+4] {order = [0, 1], shape = [128, 64]}>
+#smem = #ttg.shared_memory
+
+#linear_ds_tr_tile_out = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
+#linear_ds_tr_tile_invalid = #ttg.linear<{register = [[0, 1], [0, 2], [0, 8], [0, 4]], lane = [[1, 0], [4, 0], [2, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: b16_tests
+  tt.func @b16_tests(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK-COUNT-32: llvm.call_intrinsic "llvm.amdgcn.ds.load.tr16.b128"(%{{.*}}) : (!llvm.ptr<3>) -> vector<8xf16>
+    // CHECK-NOT: ds.load.tr16.b128
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma_b16, kWidth = 8}>>
+
+    %ptr1 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
+    %ptr2 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<64x128x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma_b16, kWidth = 8}>>
+    tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
+    tt.store %ptr2, %2 : tensor<64x128x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma_b16, kWidth = 8}>>
+    tt.return
+  }
+  // CHECK-LABEL: b16_tests_with_neg
+  tt.func @b16_tests_with_neg(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
+    // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.ds.load.tr16.b128"(%{{.*}}) : (!llvm.ptr<3>) -> vector<8xf16>
+    // CHECK-NOT: ds.load.tr16.b128
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma_b16, kWidth = 8}>>
+
+    %ptr1 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
+    %ptr2 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<64x128x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma_b16, kWidth = 8}>>
+    tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
+    tt.store %ptr2, %2 : tensor<64x128x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 1, parent = #mma_b16, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: b8_tests
+  tt.func @b8_tests(%arg0: !ttg.memdesc<128x64xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared1, #smem, mutable>, %arg2: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK-COUNT-48: llvm.call_intrinsic "llvm.amdgcn.ds.load.tr8.b64"(%{{.*}}) : (!llvm.ptr<3>) -> vector<2xi32>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma_b8_2x, kWidth = 16}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared1, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma_b8, kWidth = 8}>>
+    // CHECK-NOT: ds.load.tr8.b64
+    %ptr1 = tt.splat %arg2 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #ttg.dot_op<{opIdx = 0, parent = #mma_b8_2x, kWidth = 16}>>
+    %ptr2 = tt.splat %arg2 : !tt.ptr<i8> -> tensor<64x128x!tt.ptr<i8>, #ttg.dot_op<{opIdx = 1, parent = #mma_b8, kWidth = 8}>>
+    tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<i8>, #ttg.dot_op<{opIdx = 0, parent = #mma_b8_2x, kWidth = 16}>>
+    tt.store %ptr2, %2 : tensor<64x128x!tt.ptr<i8>, #ttg.dot_op<{opIdx = 1, parent = #mma_b8, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: no_ds_read_tr
+  tt.func @no_ds_read_tr(%arg0: !ttg.memdesc<128x64xi8, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared, #smem, mutable>, %arg2: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xi8>
+    // CHECK-NOT: ds.load.tr8.b64
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared1, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma_b8_2x, kWidth = 16}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma_b8, kWidth = 8}>>
+
+    %ptr1 = tt.splat %arg2 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #ttg.dot_op<{opIdx = 0, parent = #mma_b8_2x, kWidth = 16}>>
+    %ptr2 = tt.splat %arg2 : !tt.ptr<i8> -> tensor<64x128x!tt.ptr<i8>, #ttg.dot_op<{opIdx = 1, parent = #mma_b8, kWidth = 8}>>
+    tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<i8>, #ttg.dot_op<{opIdx = 0, parent = #mma_b8_2x, kWidth = 16}>>
+    tt.store %ptr2, %2 : tensor<64x128x!tt.ptr<i8>, #ttg.dot_op<{opIdx = 1, parent = #mma_b8, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_ll
+  tt.func @ds_transpose_ll(%arg0: !ttg.memdesc<64x16xbf16, #shared, #smem>, %arg1: !tt.ptr<bf16>) {
+    // CHECK-COUNT-4: llvm.call_intrinsic "llvm.amdgcn.ds.load.tr16.b128"(%{{.*}}) : (!llvm.ptr<3>) -> vector<8xbf16>
+    // CHECK-NOT: ds.load.tr16.b128
+    %a1 = ttg.local_load %arg0 : !ttg.memdesc<64x16xbf16, #shared, #smem> -> tensor<64x16xbf16, #linear_ds_tr_tile_out>
+
+    %ptr1 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<64x16x!tt.ptr<bf16>, #linear_ds_tr_tile_out>
+    tt.store %ptr1, %a1 : tensor<64x16x!tt.ptr<bf16>, #linear_ds_tr_tile_out>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_ll_complex
+  tt.func @ds_transpose_ll_complex(%arg0: !ttg.memdesc<64x16xbf16, #shared, #smem>, %arg1: !tt.ptr<bf16>) {
+    // CHECK-COUNT-8: llvm.call_intrinsic "llvm.amdgcn.ds.load.tr16.b128"(%{{.*}}) : (!llvm.ptr<3>) -> vector<8xbf16>
+    %a1 = ttg.local_load %arg0 : !ttg.memdesc<64x16xbf16, #shared, #smem> -> tensor<64x16xbf16, #linear_ds_tr>
+
+    %ptr1 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<64x16x!tt.ptr<bf16>, #linear_ds_tr>
+    tt.store %ptr1, %a1 : tensor<64x16x!tt.ptr<bf16>, #linear_ds_tr>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_ll_invalid
+  tt.func @ds_transpose_ll_invalid(%arg0: !ttg.memdesc<64x16xbf16, #shared, #smem>, %arg1: !tt.ptr<bf16>) {
+    %a1 = ttg.local_load %arg0 : !ttg.memdesc<64x16xbf16, #shared, #smem> -> tensor<64x16xbf16, #linear_ds_tr_tile_invalid>
+    // CHECK-NOT: ds.load.tr16.b128
+    %ptr1 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<64x16x!tt.ptr<bf16>, #linear_ds_tr_tile_invalid>
+    tt.store %ptr1, %a1 : tensor<64x16x!tt.ptr<bf16>, #linear_ds_tr_tile_invalid>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_with_padding
+  tt.func @ds_transpose_with_padding(%arg0: !ttg.memdesc<128x64xf16, #padding, #smem, mutable>, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.ds.load.tr16.b128"(%{{.*}}) : (!llvm.ptr<3>) -> vector<8xf16>
+    // CHECK-NOT: ds.load.tr16.b128
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #padding, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
+
+    %ptr1 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
+    tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_padding_interval_too_small
+  tt.func @ds_transpose_padding_interval_too_small(%arg0: !ttg.memdesc<128x64xf16, #padding_vec1, #smem, mutable>, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK-NOT: ds.load.tr16.b128
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #padding_vec1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
+
+    %ptr1 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
+    tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma_b16, kWidth = 8}>>
+    tt.return
+  }
+}

test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
   tt.func @wmma3_dot_operand_bf16(%arg0: !ttg.memdesc<64x64xbf16, #shared, #smem>, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
     // GFX1250-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xbf16>
     %0 = ttg.local_load %arg0 : !ttg.memdesc<64x64xbf16, #shared, #smem> -> tensor<64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma3, kWidth = 8}>>
-    // GFX1250-COUNT-64: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xbf16>
+    // GFX1250-COUNT-8: llvm.call_intrinsic "llvm.amdgcn.ds.load.tr16.b128"(%{{.*}}) : (!llvm.ptr<3>) -> vector<8xbf16>
     %1 = ttg.local_load %arg0 : !ttg.memdesc<64x64xbf16, #shared, #smem> -> tensor<64x64xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma3, kWidth = 8}>>
 
     %ptr0 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>, #ttg.dot_op<{opIdx = 0, parent = #mma3, kWidth = 8}>>
