
Commit 95eb619

Merge OpenAI Triton commit f3cef44 (#3824)
This PR changes the Triton base from 52e7a4b to f3cef44 (Apr 1). Pass rate: 90.38%.
2 parents 7a38bd8 + cc61bc8 commit 95eb619

File tree

21 files changed: +1331, -208 lines


bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,7 @@ void registerTestAllocationPass();
 void registerTestLivenessPass();
 void registerTestMembarPass();
 void registerTestTritonAMDGPURangeAnalysis();
+void registerTestTritonAMDGPUFoldTrueCmpIOp();
 } // namespace test
 } // namespace mlir

@@ -65,6 +66,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::test::registerTestLivenessPass();
   mlir::test::registerTestMembarPass();
   mlir::test::registerTestTritonAMDGPURangeAnalysis();
+  mlir::test::registerTestTritonAMDGPUFoldTrueCmpIOp();
   mlir::triton::registerConvertTritonToTritonGPUPass();
   mlir::triton::intel::registerConvertTritonToTritonGPUWarpPass();
   mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-1d4801f22ab1fd6205b1cf625b690aefc554cd4c
+adba14acea99cc6a17d837763a3248c9d4a2fadf

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 6 additions & 2 deletions
@@ -81,8 +81,12 @@ updateEncodingForShape(Operation *op, gpu::SharedEncodingTrait encoding,
     return swizEnc;

   auto rank = tensorType.getRank();
-  SmallVector<unsigned> order(
-      swizEnc.getOrder().drop_front(swizEnc.getOrder().size() - rank));
+  auto oldOrder = swizEnc.getOrder();
+  assert(oldOrder.size() <= rank);
+  SmallVector<unsigned> order;
+  for (int i = 0; i + oldOrder.size() < rank; ++i)
+    order.push_back(rank - i - 1);
+  order.append(oldOrder.begin(), oldOrder.end());
   auto newCtaEnc = updateCTALayoutForShape(ctaLayout, tensorType.getShape());
   return gpu::SwizzledSharedEncodingAttr::get(
       ctx, swizEnc.getVec(), swizEnc.getPerPhase(), swizEnc.getMaxPhase(),
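For context on the hunk above: the new loop pads the shared-memory order with the missing leading dimensions, outermost first, whenever the stored order has fewer entries than the tensor rank, and then appends the existing order unchanged. Below is a minimal standalone C++ sketch of just that padding step; padOrderToRank and the example values are hypothetical and only illustrate the loop, they are not part of the Triton sources.

// Illustrative sketch only: mirrors the padding loop added in
// updateEncodingForShape above; padOrderToRank is a hypothetical helper.
#include <cassert>
#include <cstdio>
#include <vector>

std::vector<unsigned> padOrderToRank(const std::vector<unsigned> &oldOrder, int rank) {
  assert(static_cast<int>(oldOrder.size()) <= rank);
  std::vector<unsigned> order;
  // Prepend the missing outer dimensions, outermost first.
  for (int i = 0; i + static_cast<int>(oldOrder.size()) < rank; ++i)
    order.push_back(rank - i - 1);
  // Keep the existing order for the remaining (inner) dimensions.
  order.insert(order.end(), oldOrder.begin(), oldOrder.end());
  return order;
}

int main() {
  // Extending the rank-2 order [1, 0] to a rank-4 tensor gives [3, 2, 1, 0].
  for (unsigned d : padOrderToRank({1, 0}, 4))
    std::printf("%u ", d);
  std::printf("\n");
}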

python/test/unit/language/test_core.py

Lines changed: 4 additions & 1 deletion
@@ -4751,7 +4751,10 @@ def _kernel(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr):
     if is_interpreter():
         return

-    assert 'llvm.assume' in pgm.asm['llir']
+    assert 'llvm.intr.assume' in pgm.asm['ttgir']
+    # stream pipeliner on AMD folds true cmpi ops to %true (which LLVM itself then DCEs)
+    if not is_hip():
+        assert 'llvm.assume' in pgm.asm['llir']


 # ---------------

test/Conversion/amd/buffer_load_to_local_to_llvm.mlir

Lines changed: 120 additions & 0 deletions
@@ -162,3 +162,123 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // COMMON-LABEL: buffer_load_swizzled_simple
+  tt.func public @buffer_load_swizzled_simple(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                                              %arg1: !tt.ptr<f16>,
+                                              %arg2: tensor<16x64xi32, #blocked>,
+                                              %arg3: !ttg.memdesc<16x64xf16, #shared, #smem, mutable>) {
+    // Each thread needs to load 2 elements and we load 1 (sizePerThread) per buffer load instruction
+    // COMMON: rocdl.make.buffer.rsrc
+    // COMMON-NOT: rocdl.make.buffer.rsrc
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON-NOT: rocdl.raw.ptr.buffer.load.lds
+    %65 = amdgpu.buffer_load_to_local %arg1[%arg2] into %arg3 {OpIdx = #amdgpu.OpIdx<1>} : <f16>[tensor<16x64xi32, #blocked>] -> <16x64xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 2, maxPhase = 8, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // COMMON-LABEL: buffer_load_to_local_swizzled_mask_other
+  tt.func public @buffer_load_to_local_swizzled_mask_other(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                                                           %arg1: !tt.ptr<f16>,
+                                                           %arg2: tensor<32x32xi32, #blocked>,
+                                                           %arg3: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>,
+                                                           %arg4: i32) {
+    // We need the splat to allow the AxisAnalysis to work during lowering
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked>
+    %c0_i32 = arith.constant 0 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %c31_i32 = arith.constant 31 : i32
+    %1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
+    %29 = arith.addi %arg4, %c31_i32 : i32
+    %30 = arith.divsi %29, %c32_i32 : i32
+    %31 = arith.cmpi sgt, %30, %c0_i32 : i32
+
+    %51 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %52 = tt.expand_dims %51 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
+    %65 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked>
+    %66 = arith.cmpi slt, %52, %65 : tensor<32x1xi32, #blocked>
+    %67 = tt.broadcast %66 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
+
+    %70 = tt.splat %31 : i1 -> tensor<32x32xi1, #blocked>
+    %71 = arith.andi %70, %67 : tensor<32x32xi1, #blocked>
+
+    // Each thread needs to load 4 elements and we load 1 (sizePerThread) per buffer load instruction
+    // Note that mask/other alignment is 1 so we need 4 conditionals
+
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.ballot
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: _predicated_store
+
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.ballot
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: _predicated_store
+
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.ballot
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: _predicated_store
+
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.ballot
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: _predicated_store
+
+    // COMMON-NOT: rocdl.ds_bpermute
+    // COMMON-NOT: rocdl.ballot
+    // COMMON-NOT: rocdl.raw.ptr.buffer.load.lds
+    // COMMON-NOT: _predicated_store
+
+    amdgpu.buffer_load_to_local %arg1[%arg2] mask=%67 other=%cst_0 into %arg3 {OpIdx = #amdgpu.OpIdx<1>} : <f16>[tensor<32x32xi32, #blocked>] tensor<32x32xf16, #blocked> -> <32x32xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 32], order = [0, 1]}>
+#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 16, order = [0, 1]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.shared = 0 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // COMMON-LABEL: buffer_load_to_local_swizzled_vectorized_8xf16
+  tt.func public @buffer_load_to_local_swizzled_vectorized_8xf16(%arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>) {
+    %cst = arith.constant dense<64> : tensor<1x64xi32, #blocked>
+    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %3 = tt.broadcast %2 : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked>
+    %4 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+    %5 = arith.muli %4, %cst : tensor<1x64xi32, #blocked>
+    %6 = tt.broadcast %5 : tensor<1x64xi32, #blocked> -> tensor<64x64xi32, #blocked>
+    %7 = arith.addi %3, %6 : tensor<64x64xi32, #blocked>
+
+    // Each thread needs to load 8 elements and we load 8 (sizePerThread) per buffer load instruction
+    // GFX950: rocdl.make.buffer.rsrc
+    // GFX950: rocdl.ds_bpermute
+    // GFX950: rocdl.raw.ptr.buffer.load.lds
+    // GFX950-NOT: rocdl.raw.ptr.buffer.load.lds
+
+    // GFX942 does not support vectorization > 4bytes so we cannot lower it
+    // GFX942-NOT: rocdl.raw.ptr.buffer.load.lds
+    // GFX942: amdgpu.buffer_load_to_local
+    %8 = amdgpu.buffer_load_to_local %arg1[%7] into %arg2 : <f16>[tensor<64x64xi32, #blocked>] -> <64x64xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
