Skip to content

Commit f0a2fcc

Browse files
committed
Add back attribute
1 parent a8dcad0 commit f0a2fcc

File tree

12 files changed

+85
-33
lines changed

12 files changed

+85
-33
lines changed

test/Conversion/intel/dpas_to_block_layout_convert.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 67584 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
77
// CHECK-LABEL: llvm.func spir_kernelcc @convert_dpas(
88
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<1>,
9-
// CHECK-SAME: %[[SCRATCH_SLM:.*]]: !llvm.ptr<3>) attributes {intel_reqd_sub_group_size = 16 : i32, noinline = false, reqd_work_group_size = array<i32: 512, 1, 1>} {
9+
// CHECK-SAME: %[[SCRATCH_SLM:.*]]: !llvm.ptr<3>) attributes {intel_reqd_sub_group_size = 16 : i32, noinline = false, triton_gen.max_work_group_size = array<i32: 512, 1, 1>} {
1010
tt.func public @convert_dpas(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
1111
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #mma>
1212

@@ -69,7 +69,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
6969
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 67584 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
7070
// CHECK-LABEL: llvm.func spir_kernelcc @convert_dpas(
7171
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<1>,
72-
// CHECK-SAME: %[[SCRATCH_SLM:.*]]: !llvm.ptr<3>) attributes {intel_reqd_sub_group_size = 16 : i32, noinline = false, reqd_work_group_size = array<i32: 512, 1, 1>} {
72+
// CHECK-SAME: %[[SCRATCH_SLM:.*]]: !llvm.ptr<3>) attributes {intel_reqd_sub_group_size = 16 : i32, noinline = false, triton_gen.max_work_group_size = array<i32: 512, 1, 1>} {
7373
tt.func public @convert_dpas(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
7474
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #mma>
7575

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
44
// CHECK: llvm.func spir_kernelcc @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>)
55
// Here the 128 comes from the 4 in module attribute multiples 32
6-
// CHECK-SAME: attributes {intel_reqd_sub_group_size = 32 : i32, reqd_work_group_size = array<i32: 128, 1, 1>} {
6+
// CHECK-SAME: attributes {intel_reqd_sub_group_size = 32 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>} {
77
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
88
// CHECK: llvm.return
99
tt.return

test/Conversion/intel/tritongpu_to_gen_dot.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
106106
// CHECK: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
107107
// CHECK-LABEL: llvm.func spir_kernelcc @dot_rep_cluster_4_2(
108108
// CHECK-SAME: %[[A:.*]]: !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>, %[[B:.*]]: !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>,
109-
// CHECK-SAME: %[[C:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 16, 1, 1>} {
109+
// CHECK-SAME: %[[C:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 16, 1, 1>} {
110110
tt.func @dot_rep_cluster_4_2(%a: tensor<32x32xf16, #dot_operand_a>, %b: tensor<32x32xf16, #dot_operand_b>, %c: tensor<32x32xf32, #dpas>) {
111111
// CHECK: %[[VAL_3:.*]] = llvm.mlir.undef : vector<8xf32>
112112
// CHECK: %[[CST_15:.*]] = llvm.mlir.constant(15 : i32) : i32

test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.sup
114114

115115
module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.support_dpas", "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
116116
// CHECK-LABEL: llvm.func spir_kernelcc @matmul_kernel_with_block_pointers_tf32(
117-
// CHECK-SAME: [[VAL_0:%.*]]: !llvm.ptr<1>) attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 512, 1, 1>} {
117+
// CHECK-SAME: [[VAL_0:%.*]]: !llvm.ptr<1>) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 512, 1, 1>} {
118118
tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr<f32>) {
119119
%c0_i64 = arith.constant 0 : i64
120120
%c0_i32 = arith.constant 0 : i32
@@ -134,7 +134,7 @@ module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.sup
134134

135135
module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.support_dpas", "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
136136
// CHECK-LABEL: llvm.func spir_kernelcc @matmul_kernel_with_block_pointers_f16accu(
137-
// CHECK-SAME: [[VAL_0:%.*]]: !llvm.ptr<1>) attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 512, 1, 1>} {
137+
// CHECK-SAME: [[VAL_0:%.*]]: !llvm.ptr<1>) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 512, 1, 1>} {
138138
tt.func public @matmul_kernel_with_block_pointers_f16accu(%arg0: !tt.ptr<f16>) {
139139
%c0_i64 = arith.constant 0 : i64
140140
%c0_i32 = arith.constant 0 : i32
@@ -157,7 +157,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
157157
// CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_addf(f32) -> f32
158158

159159
// CHECK-LABEL: llvm.func spir_kernelcc @reduce_sum(
160-
// CHECK-SAME: [[VAL_0:%.*]]: vector<8xf32>) -> f32 attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 128, 1, 1>}
160+
// CHECK-SAME: [[VAL_0:%.*]]: vector<8xf32>) -> f32 attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>}
161161
tt.func public @reduce_sum(%arg0: tensor<8x16xf32>) -> f32 {
162162
// CHECK: [[VAL_1:%.*]] = llvm.mlir.constant(0 : i32) : i32
163163
// CHECK: [[VAL_2:%.*]] = llvm.extractelement [[VAL_0]][[[VAL_1]] : i32] : vector<8xf32>
@@ -172,7 +172,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
172172
}
173173

174174
// CHECK-LABEL: llvm.func spir_kernelcc @reduce_max(
175-
// CHECK-SAME: [[VAL_0:%.*]]: vector<8xf32>) -> f32 attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 128, 1, 1>}
175+
// CHECK-SAME: [[VAL_0:%.*]]: vector<8xf32>) -> f32 attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>}
176176
tt.func public @reduce_max(%arg0: tensor<8x16xf32>) -> f32 {
177177
// CHECK: [[VAL_1:%.*]] = llvm.mlir.constant(0 : i32) : i32
178178
// CHECK: [[VAL_2:%.*]] = llvm.extractelement [[VAL_0]][[[VAL_1]] : i32] : vector<8xf32>
@@ -229,7 +229,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
229229
}
230230

231231
// CHECK-LABEL: llvm.func spir_kernelcc @addptr(
232-
// CHECK-SAME: [[VAL_0:%.*]]: !llvm.ptr<1>) -> !llvm.ptr<1> attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 128, 1, 1>}
232+
// CHECK-SAME: [[VAL_0:%.*]]: !llvm.ptr<1>) -> !llvm.ptr<1> attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>}
233233
tt.func public @addptr(%arg0: !tt.ptr<f16>) -> !tt.ptr<f16> {
234234
// CHECK: [[VAL_1:%.*]] = llvm.mlir.constant(0 : i32) : i32
235235
// CHECK: [[VAL_2:%.*]] = llvm.call spir_funccc @_Z12get_group_idj([[VAL_1]]) {{.*}} : (i32) -> i64
@@ -368,7 +368,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
368368
#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
369369

370370
// CHECK-LABEL: llvm.func spir_kernelcc @test(
371-
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<16xf32> attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 64, 1, 1>} {
371+
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<16xf32> attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 64, 1, 1>} {
372372
// CHECK: %[[VAL_2:.*]] = llvm.mlir.poison : vector<16xf32>
373373
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i32) : i32
374374
// CHECK: %[[VAL_4:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_0]], %[[VAL_3]])

test/Conversion/intel/tritongpu_transposed_reduction.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
1818
// CHECK: }
1919

2020
// CHECK: llvm.func spir_kernelcc @reduce_sum(
21-
// CHECK-SAME: %[[VAL_0:.*]]: vector<16xf32>) -> f32 attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 128, 1, 1>} {
21+
// CHECK-SAME: %[[VAL_0:.*]]: vector<16xf32>) -> f32 attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>} {
2222
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
2323
// CHECK: %[[VAL_3:.*]] = llvm.extractelement %[[VAL_0]]{{\[}}%[[VAL_2]] : i32] : vector<16xf32>
2424
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -78,7 +78,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
7878
}
7979

8080
// CHECK: llvm.func spir_kernelcc @reduce_max(
81-
// CHECK-SAME: %[[VAL_0:.*]]: vector<16xf32>) -> f32 attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 128, 1, 1>} {
81+
// CHECK-SAME: %[[VAL_0:.*]]: vector<16xf32>) -> f32 attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>} {
8282
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
8383
// CHECK: %[[VAL_3:.*]] = llvm.extractelement %[[VAL_0]]{{\[}}%[[VAL_2]] : i32] : vector<16xf32>
8484
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32

test/Target/LLVMIR/triton-gen.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
// RUN: triton-translate -triton-to-llvmir -split-input-file %s | FileCheck %s
22

3+
// CHECK: define spir_kernel void @test_max_work_group_size() !max_work_group_size ![[MAX_WORK_GROUP_SIZE:.*]] {
4+
llvm.func spir_kernelcc @test_max_work_group_size() attributes {triton_gen.max_work_group_size = [128 : i32, 1 : i32, 1 : i32]} {
5+
llvm.return
6+
}
7+
8+
// -----
9+
310
llvm.func @foo(%arg0: !llvm.ptr, %arg1: !llvm.ptr)
411

512
// CHECK-LABEL: define void @triton_gen.cache_controls(

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
5757
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
5858
// CHECK-LABEL: llvm.func spir_kernelcc @dot_op_a_2d_load(
5959
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<1>,
60-
// CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64, %[[VAL_3:.*]]: i64, %[[VAL_4:.*]]: i64) attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 128, 1, 1>} {
60+
// CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64, %[[VAL_3:.*]]: i64, %[[VAL_4:.*]]: i64) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>} {
6161
tt.func public @dot_op_a_2d_load(%arg0: !tt.ptr<f16>, %arg2: i64, %arg4: i64, %arg5: i64, %arg7: i64) {
6262
%c0_i32 = arith.constant 0 : i32
6363
%c1_i64 = arith.constant 1 : i64
@@ -129,7 +129,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
129129
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
130130
// CHECK-LABEL: llvm.func spir_kernelcc @dot_op_b_2d_load(
131131
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<1>,
132-
// CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64, %[[VAL_3:.*]]: i64) attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 128, 1, 1>} {
132+
// CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64, %[[VAL_3:.*]]: i64) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>} {
133133
tt.func public @dot_op_b_2d_load(%arg1: !tt.ptr<f16>, %arg3: i64, %arg4: i64, %arg7: i64) {
134134
%c0_i32 = arith.constant 0 : i32
135135
%c1_i64 = arith.constant 1 : i64

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
6161
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
6262
// CHECK-LABEL: llvm.func spir_kernelcc @dpas_layout_2d_store_rep_cluster_4_2(
6363
// CHECK-SAME: %[[base:.*]]: !llvm.ptr<1>,
64-
// CHECK-SAME: %[[width:.*]]: i64, %[[height:.*]]: i64, %[[rowStride:.*]]: i64) attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 128, 1, 1>} {
64+
// CHECK-SAME: %[[width:.*]]: i64, %[[height:.*]]: i64, %[[rowStride:.*]]: i64) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>} {
6565
tt.func public @dpas_layout_2d_store_rep_cluster_4_2(%base: !tt.ptr<f16>, %width: i64, %height: i64, %rowStride: i64) {
6666
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #dpas>
6767
%c0_i32 = arith.constant 0 : i32

test/TritonIntelGPU/tritonintelgpu-convert-layout-shortcut.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [32, 1], repCluster = [1, 2], A = [8, 16], B = [16, 32], C = [8, 32]}>
44
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
55
// CHECK-LABEL: convert_dpas_to_dot_rep_cluster_1_2
6-
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<({{.*}})>) attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 512, 1, 1>} {
6+
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<({{.*}})>) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 512, 1, 1>} {
77
tt.func public @convert_dpas_to_dot_rep_cluster_1_2(%arg: tensor<1024x32xf16, #dpas>) {
88
// COM: The repetitions order of dot layout and dpas layout are same when the GEMM tiling is clustered as repCluster [1, 2].
99
// CHECK: %[[VAL_81:.*]] = llvm.mlir.constant(7 : i32) : i32
@@ -56,7 +56,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
5656
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [32, 1], repCluster = [2, 2], A = [8, 16], B = [16, 32], C = [8, 32]}>
5757
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
5858
// CHECK-LABEL: convert_dpas_to_dot_rep_cluster_2_2
59-
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<({{.*}})>) attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 512, 1, 1>} {
59+
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<({{.*}})>) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 512, 1, 1>} {
6060
tt.func public @convert_dpas_to_dot_rep_cluster_2_2(%arg: tensor<1024x32xf16, #dpas>) {
6161
// COM: The repetitions order of dpas layout when the GEMM tiling is clustered as repCluster [2, 2]:
6262
// COM: - 0, 1, 2, 3, 4, 5, 6, 7.
@@ -112,7 +112,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
112112
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [32, 1], repCluster = [4, 2], A = [8, 16], B = [16, 32], C = [8, 32]}>
113113
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
114114
// CHECK-LABEL: convert_dpas_to_dot_rep_cluster_4_2
115-
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<({{.*}})>) attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 512, 1, 1>} {
115+
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<({{.*}})>) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 512, 1, 1>} {
116116
tt.func public @convert_dpas_to_dot_rep_cluster_4_2(%arg: tensor<1024x32xf16, #dpas>) {
117117
// COM: The repetitions order of dpas layout when the GEMM tiling is clustered as repCluster [4, 2]:
118118
// COM: - 0, 1, 2, 3, 4, 5, 6, 7.

third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ def TritonGEN_Dialect : Dialect {
2424
let dependentDialects = ["mlir::LLVM::LLVMDialect"];
2525

2626
let extraClassDeclaration = [{
27+
/// Get the name of the attribute used to annotate max work group size
28+
/// required for kernels.
29+
static constexpr ::llvm::StringLiteral getMaxWorkGroupSizeAttrName() {
30+
return ::llvm::StringLiteral("triton_gen.max_work_group_size");
31+
}
32+
2733
/// Get the name for the attribute used to specify cache control
2834
/// decorations.
2935
static constexpr ::llvm::StringRef getCacheControlsAttrName() {

0 commit comments

Comments
 (0)