Skip to content

Commit 03ff0ff

Browse files
authored
Merge branch 'main' into fast-sub-group-transpose-extend
2 parents 5e2029d + a5140b7 commit 03ff0ff

File tree

9 files changed

+908
-31
lines changed

9 files changed

+908
-31
lines changed

.github/workflows/auto-update-translator-cid.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ jobs:
8686
- name: Search the latest valid Translator cid
8787
if: ${{ env.TARGET_PRID == null }}
8888
run: |
89-
env
9089
./scripts/check-update-translator-cid.sh $CID_LATEST $CID_CURRENT
9190
if git status --porcelain ./lib/Target/SPIRV/spirv-llvm-translator.conf | grep '^ M'; then
9291
echo "MODIFIED=true" >> $GITHUB_ENV

scripts/check-update-translator-cid.sh

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,22 @@ cd $TRITON_PROJ
2626
FOUND=false
2727
for cid in $COMMIT_IDS; do
2828
echo "$cid" > ./lib/Target/SPIRV/spirv-llvm-translator.conf
29-
if ! ./scripts/compile-triton.sh --clean; then
30-
echo "Triton compile failed for translator commit $cid"
29+
30+
BUILD_STATUS=PASS
31+
echo "::group::Building Triton for $cid"
32+
./scripts/compile-triton.sh --clean || BUILD_STATUS=FAIL
33+
echo "::endgroup::"
34+
35+
if [ $BUILD_STATUS != PASS ]; then
3136
continue
3237
fi
3338

34-
# execute default tests
35-
if ./scripts/test-triton.sh --skip-pytorch-install; then
39+
TEST_STATUS=PASS
40+
echo "::group::Testing Triton for $cid"
41+
./scripts/test-triton.sh --skip-pytorch-install || TEST_STATUS=FAIL
42+
echo "::endgroup::"
43+
44+
if [ $TEST_STATUS = PASS ]; then
3645
echo "Tests passed for translator commit $cid"
3746
echo "A newer commit found: $cid"
3847
FOUND=true

scripts/compile-triton.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ build_triton() {
179179

180180
cd python
181181
# Install triton and its dependencies.
182-
pip install -vvv -e '.[build,tests]'
182+
pip install -v -e '.[build,tests]'
183183

184184
# Copy compile_commands.json from the build directory to the project root (so that the clangd VS Code plugin can find it).
185185
cp $(find $TRITON_PROJ_BUILD -name compile_commands.json) $TRITON_PROJ/

scripts/skiplist/mtl/.gitkeep

Whitespace-only changes.

scripts/skiplist/mtl/language.txt

Lines changed: 300 additions & 0 deletions
Large diffs are not rendered by default.

scripts/skiplist/mtl/tutorials.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
02-fused-softmax
2+
03-matrix-multiplication
3+
06-fused-attention
4+
08-grouped-gemm
5+
10-experimental-block-pointer
6+
10i-experimental-block-pointer

test/Conversion/intel/sub-group-shuffle.mlir

Lines changed: 259 additions & 0 deletions
Large diffs are not rendered by default.

test/Conversion/intel/sub-group-transpose.mlir

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,132 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16
297297
tt.return %0 : tensor<64x16x4xf32, #blocked1>
298298
}
299299
}
300+
301+
// -----
302+
303+
// Test transposition with 32 elements per work-item.
304+
305+
#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
306+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
307+
308+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
309+
// CHECK-LABEL: llvm.func spir_kernelcc @test(
310+
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
311+
tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> {
312+
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
313+
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
314+
// CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
315+
// CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
316+
// CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
317+
// CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
318+
// CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
319+
// CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
320+
// CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
321+
// CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
322+
// CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
323+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
324+
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
325+
// CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
326+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
327+
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
328+
// CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
329+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
330+
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
331+
// CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
332+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
333+
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
334+
// CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
335+
// CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
336+
// CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
337+
// CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
338+
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
339+
%0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1>
340+
tt.return %0 : tensor<32x16xf32, #blocked1>
341+
}
342+
}
343+
344+
// -----
345+
346+
// Test transposition with 32 elements per work-item with a different tensor shape (16x32 instead of 32x16, same blocked layout attributes).
347+
348+
#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
349+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
350+
351+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
352+
// CHECK-LABEL: llvm.func spir_kernelcc @test(
353+
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
354+
tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> {
355+
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
356+
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
357+
// CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
358+
// CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
359+
// CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
360+
// CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
361+
// CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
362+
// CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
363+
// CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
364+
// CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
365+
// CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
366+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
367+
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
368+
// CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
369+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
370+
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
371+
// CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
372+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
373+
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
374+
// CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
375+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
376+
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
377+
// CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
378+
// CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
379+
// CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
380+
// CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
381+
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
382+
%0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1>
383+
tt.return %0 : tensor<16x32xf32, #blocked1>
384+
}
385+
}
386+
387+
// -----
388+
389+
// Test transposition with 32 elements per work-item and two warps in each dimension.
390+
391+
#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}>
392+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
393+
394+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
395+
// CHECK-LABEL: llvm.func spir_kernelcc @test(
396+
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
397+
tt.func @test(%arg0: tensor<32x64xf32, #blocked>) -> tensor<32x64xf32, #blocked1> {
398+
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
399+
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
400+
// CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
401+
// CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
402+
// CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
403+
// CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
404+
// CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
405+
// CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
406+
// CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
407+
// CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
408+
// CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
409+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
410+
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
411+
// CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
412+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
413+
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
414+
// CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
415+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
416+
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
417+
// CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
418+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
419+
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
420+
// CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
421+
// CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
422+
// CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
423+
// CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
424+
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
425+
%0 = triton_gpu.convert_layout %arg0 : tensor<32x64xf32, #blocked> -> tensor<32x64xf32, #blocked1>
426+
tt.return %0 : tensor<32x64xf32, #blocked1>
427+
}
428+
}

0 commit comments

Comments
 (0)