Skip to content

Commit 6588f0d

Browse files
authored
[XPU][TritonGEN] Replace subgroup reduce and scan (#2893)
Replace usage of the TritonGEN dialect subgroup reduce and scan operations with equivalent operations from the SPIR-V dialect. This closes #2892. --------- Signed-off-by: Lukas Sommer <[email protected]>
1 parent 78c13a5 commit 6588f0d

File tree

14 files changed

+175
-592
lines changed

14 files changed

+175
-592
lines changed

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,63 +1500,63 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "triton_
15001500
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
15011501
tt.func public @reduce_all(%arg: tensor<256x1xi32, #blocked>, %arg_0: tensor<256x1xf32, #blocked>) {
15021502

1503-
// CHECK: @_Z32sub_group_non_uniform_reduce_addf
1503+
// CHECK: @_Z27__spirv_GroupNonUniformFAddiif
15041504
%0 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
15051505
^bb0(%arg4: f32, %arg5: f32):
15061506
%48 = arith.addf %arg4, %arg5 : f32
15071507
tt.reduce.return %48 : f32
15081508
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
15091509

1510-
// CHECK: @_Z32sub_group_non_uniform_reduce_addi
1510+
// CHECK: @_Z27__spirv_GroupNonUniformIAddiij
15111511
%1 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
15121512
^bb0(%arg4: i32, %arg5: i32):
15131513
%48 = arith.addi %arg4, %arg5 : i32
15141514
tt.reduce.return %48 : i32
15151515
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
15161516

1517-
// CHECK: @_Z32sub_group_non_uniform_reduce_mulf
1517+
// CHECK: @_Z27__spirv_GroupNonUniformFMuliif
15181518
%2 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
15191519
^bb0(%arg4: f32, %arg5: f32):
15201520
%48 = arith.mulf %arg4, %arg5 : f32
15211521
tt.reduce.return %48 : f32
15221522
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
15231523

1524-
// CHECK: @_Z32sub_group_non_uniform_reduce_muli
1524+
// CHECK: @_Z27__spirv_GroupNonUniformIMuliij
15251525
%3 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
15261526
^bb0(%arg4: i32, %arg5: i32):
15271527
%48 = arith.muli %arg4, %arg5 : i32
15281528
tt.reduce.return %48 : i32
15291529
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
15301530

1531-
// CHECK: @_Z32sub_group_non_uniform_reduce_maxf
1531+
// CHECK: @_Z27__spirv_GroupNonUniformFMaxiif
15321532
%4 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
15331533
^bb0(%arg4: f32, %arg5: f32):
15341534
%48 = arith.maxnumf %arg4, %arg5 : f32
15351535
tt.reduce.return %48 : f32
15361536
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
15371537

1538-
// CHECK: @_Z32sub_group_non_uniform_reduce_minf
1538+
// CHECK: @_Z27__spirv_GroupNonUniformFMiniif
15391539
%5 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
15401540
^bb0(%arg4: f32, %arg5: f32):
15411541
%48 = arith.minnumf %arg4, %arg5 : f32
15421542
tt.reduce.return %48 : f32
15431543
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
15441544

1545-
// CHECK: @_Z32sub_group_non_uniform_reduce_andi
1545+
// CHECK: @_Z33__spirv_GroupNonUniformBitwiseAndiij
15461546
%6 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
15471547
^bb0(%arg4: i32, %arg5: i32):
15481548
%48 = arith.andi %arg4, %arg5 : i32
15491549
tt.reduce.return %48 : i32
15501550
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
15511551

1552-
// CHECK: @_Z31sub_group_non_uniform_reduce_ori
1552+
// CHECK: @_Z32__spirv_GroupNonUniformBitwiseOriij
15531553
%7 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
15541554
^bb0(%arg4: i32, %arg5: i32):
15551555
%48 = arith.ori %arg4, %arg5 : i32
15561556
tt.reduce.return %48 : i32
15571557
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
15581558

1559-
// CHECK: @_Z32sub_group_non_uniform_reduce_xori
1559+
// CHECK: @_Z33__spirv_GroupNonUniformBitwiseXoriij
15601560
%8 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
15611561
^bb0(%arg4: i32, %arg5: i32):
15621562
%48 = arith.xori %arg4, %arg5 : i32
@@ -1575,63 +1575,63 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
15751575
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
15761576
tt.func public @reduce_cluster(%arg: tensor<256x1xi32, #blocked>, %arg_0: tensor<256x1xf32, #blocked>) {
15771577

1578-
// CHECK: @_Z30sub_group_clustered_reduce_addfj
1578+
// CHECK: @_Z27__spirv_GroupNonUniformFAddiif
15791579
%0 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
15801580
^bb0(%arg4: f32, %arg5: f32):
15811581
%48 = arith.addf %arg4, %arg5 : f32
15821582
tt.reduce.return %48 : f32
15831583
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
15841584

1585-
// CHECK: @_Z30sub_group_clustered_reduce_addij
1585+
// CHECK: @_Z27__spirv_GroupNonUniformIAddiij
15861586
%1 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
15871587
^bb0(%arg4: i32, %arg5: i32):
15881588
%48 = arith.addi %arg4, %arg5 : i32
15891589
tt.reduce.return %48 : i32
15901590
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
15911591

1592-
// CHECK: @_Z30sub_group_clustered_reduce_mulfj
1592+
// CHECK: @_Z27__spirv_GroupNonUniformFMuliif
15931593
%2 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
15941594
^bb0(%arg4: f32, %arg5: f32):
15951595
%48 = arith.mulf %arg4, %arg5 : f32
15961596
tt.reduce.return %48 : f32
15971597
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
15981598

1599-
// CHECK: @_Z30sub_group_clustered_reduce_mulij
1599+
// CHECK: @_Z27__spirv_GroupNonUniformIMuliij
16001600
%3 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
16011601
^bb0(%arg4: i32, %arg5: i32):
16021602
%48 = arith.muli %arg4, %arg5 : i32
16031603
tt.reduce.return %48 : i32
16041604
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
16051605

1606-
// CHECK: @_Z30sub_group_clustered_reduce_maxfj
1606+
// CHECK: @_Z27__spirv_GroupNonUniformFMaxiif
16071607
%4 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
16081608
^bb0(%arg4: f32, %arg5: f32):
16091609
%48 = arith.maxnumf %arg4, %arg5 : f32
16101610
tt.reduce.return %48 : f32
16111611
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
16121612

1613-
// CHECK: @_Z30sub_group_clustered_reduce_minfj
1613+
// CHECK: @_Z27__spirv_GroupNonUniformFMiniif
16141614
%5 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
16151615
^bb0(%arg4: f32, %arg5: f32):
16161616
%48 = arith.minnumf %arg4, %arg5 : f32
16171617
tt.reduce.return %48 : f32
16181618
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
16191619

1620-
// CHECK: @_Z30sub_group_clustered_reduce_andij
1620+
// CHECK: @_Z33__spirv_GroupNonUniformBitwiseAndiij
16211621
%6 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
16221622
^bb0(%arg4: i32, %arg5: i32):
16231623
%48 = arith.andi %arg4, %arg5 : i32
16241624
tt.reduce.return %48 : i32
16251625
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
16261626

1627-
// CHECK: @_Z29sub_group_clustered_reduce_orij
1627+
// CHECK: @_Z32__spirv_GroupNonUniformBitwiseOriij
16281628
%7 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
16291629
^bb0(%arg4: i32, %arg5: i32):
16301630
%48 = arith.ori %arg4, %arg5 : i32
16311631
tt.reduce.return %48 : i32
16321632
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
16331633

1634-
// CHECK: @_Z30sub_group_clustered_reduce_xorij
1634+
// CHECK: @_Z33__spirv_GroupNonUniformBitwiseXoriij
16351635
%8 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
16361636
^bb0(%arg4: i32, %arg5: i32):
16371637
%48 = arith.xori %arg4, %arg5 : i32
@@ -1645,9 +1645,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
16451645
// -----
16461646

16471647
// CHECK-LABEL: sum_reduction
1648-
// CHECK: llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_addi(%{{.*}}) {{.*}} : (i32) -> i32
1648+
// CHECK: llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIAddiij(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, i32, i32) -> i32
16491649
// CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
1650-
// CHECK: llvm.call spir_funccc @_Z30sub_group_clustered_reduce_addij(%{{.*}}, %{{.*}}) {{.*}}convergent{{.*}}no_unwind{{.*}}will_return{{.*}} : (i32, i32) -> i32
1650+
// CHECK: llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIAddiijj(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, i32, i32, i32) -> i32
16511651

16521652
// CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
16531653
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>

test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,12 @@ module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.sup
153153
#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
154154
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
155155

156-
// CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_maxf(f32) -> f32
157-
// CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_addf(f32) -> f32
158-
159156
// CHECK-LABEL: llvm.func spir_kernelcc @reduce_sum(
160157
// CHECK-SAME: [[VAL_0:%.*]]: vector<8xf32>) -> f32 attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>}
161158
tt.func public @reduce_sum(%arg0: tensor<8x16xf32>) -> f32 {
162159
// CHECK: [[VAL_1:%.*]] = llvm.mlir.constant(0 : i32) : i32
163160
// CHECK: [[VAL_2:%.*]] = llvm.extractelement [[VAL_0]][[[VAL_1]] : i32] : vector<8xf32>
164-
// CHECK: [[VAL_3:%.*]] = llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_addf([[VAL_2]]) {{.*}} : (f32) -> f32
161+
// CHECK: [[VAL_3:%.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFAddiif(%{{.*}}, %{{.*}}, [[VAL_2]]) {{.*}} : (i32, i32, f32) -> f32
165162
%0 = triton_intel_gpu.extract %arg0[0] : tensor<8x16xf32> -> tensor<16xf32>
166163
%1 = "tt.reduce"(%0) <{axis = 0 : i32}> ({
167164
^bb0(%arg1: f32, %arg2: f32):
@@ -176,7 +173,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
176173
tt.func public @reduce_max(%arg0: tensor<8x16xf32>) -> f32 {
177174
// CHECK: [[VAL_1:%.*]] = llvm.mlir.constant(0 : i32) : i32
178175
// CHECK: [[VAL_2:%.*]] = llvm.extractelement [[VAL_0]][[[VAL_1]] : i32] : vector<8xf32>
179-
// CHECK: [[VAL_3:%.*]] = llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_maxf([[VAL_2]]) {{.*}} : (f32) -> f32
176+
// CHECK: [[VAL_3:%.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFMaxiif(%{{.*}}, %{{.*}}, [[VAL_2]]) {{.*}} : (i32, i32, f32) -> f32
180177
%0 = triton_intel_gpu.extract %arg0[0] : tensor<8x16xf32> -> tensor<16xf32>
181178
%1 = "tt.reduce"(%0) <{axis = 0 : i32}> ({
182179
^bb0(%arg1: f32, %arg2: f32):

test/TritonGEN/gpu-to-tritongen.mlir

Lines changed: 0 additions & 40 deletions
This file was deleted.

test/TritonGEN/tritongen-invalid.mlir

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -16,54 +16,6 @@ llvm.func @triton_gen.illegal_cache_controls_attr(%arg0: !llvm.ptr) {
1616

1717
// -----
1818

19-
llvm.func @triton_gen.sub_group_reduce() {
20-
// expected-error @+2 {{'triton_gen.sub_group_reduce' op expecting valid target env attribute}}
21-
%0 = llvm.mlir.constant(0 : i32) : i32
22-
%1 = triton_gen.sub_group_reduce add %0 {size = 16} : i32
23-
llvm.return
24-
}
25-
26-
// -----
27-
28-
module attributes {
29-
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 16>>
30-
} {
31-
llvm.func @triton_gen.sub_group_reduce() {
32-
// expected-error @+2 {{'triton_gen.sub_group_reduce' op expecting size to be a power of 2 between 1 and subgroup size}}
33-
%0 = llvm.mlir.constant(0 : i32) : i32
34-
%1 = triton_gen.sub_group_reduce add %0 {size = 0} : i32
35-
llvm.return
36-
}
37-
}
38-
39-
// -----
40-
41-
module attributes {
42-
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 16>>
43-
} {
44-
llvm.func @triton_gen.sub_group_reduce() {
45-
// expected-error @+2 {{'triton_gen.sub_group_reduce' op expecting size to be a power of 2 between 1 and subgroup size}}
46-
%0 = llvm.mlir.constant(0 : i32) : i32
47-
%1 = triton_gen.sub_group_reduce add %0 {size = 32} : i32
48-
llvm.return
49-
}
50-
}
51-
52-
// -----
53-
54-
module attributes {
55-
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 16>>
56-
} {
57-
llvm.func @triton_gen.sub_group_reduce() {
58-
// expected-error @+2 {{'triton_gen.sub_group_reduce' op expecting size to be a power of 2 between 1 and subgroup size}}
59-
%0 = llvm.mlir.constant(0 : i32) : i32
60-
%1 = triton_gen.sub_group_reduce add %0 {size = 6} : i32
61-
llvm.return
62-
}
63-
}
64-
65-
// -----
66-
6719
llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
6820
// expected-error @+1 {{'triton_gen.dpas' op expecting repeat count to be 1, 2, 4, or 8}}
6921
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=16} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>

0 commit comments

Comments
 (0)