Skip to content

Commit b19c43a

Browse files
Fix build and test failures from 68a24ff
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 8065e6a commit b19c43a

File tree

6 files changed

+28
-25
lines changed

6 files changed

+28
-25
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2993,7 +2993,7 @@ struct TritonGPUVerifyTensorLayoutInterface
29932993
<< rankedTy.getShape()
29942994
<< " which is not a power of two.";
29952995
}
2996-
auto ll = toLinearLayout(rankedTy.getShape(), layout);
2996+
auto ll = toLinearLayout(rankedTy);
29972997
ModuleOp module = op->getParentOfType<ModuleOp>();
29982998

29992999
// Number of threads per warp.

python/test/unit/intel/test_block_load.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,21 +21,24 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa
2121
A_width = 2
2222
B_width = 4
2323
layouts = "#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 2]}>"
24+
num_warps = 4
2425
elif dtype_str == "float32":
2526
A_width = 1
2627
B_width = 1
2728
layouts = "#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2]}>"
29+
num_warps = 32
2830
else:
2931
A_width = 1
3032
B_width = 2
3133
layouts = "#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2]}>"
34+
num_warps = 32
3235

3336
block_io = "\"column_major\"" if transpose else "\"row_major\""
3437

3538
ty = {"float32": "f32", "float16": "f16", "int8": "i8"}[dtype_str]
3639

3740
ir = layouts + f"""
38-
module attributes {{ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32}} {{
41+
module attributes {{ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32}} {{
3942
tt.func public @block_load_dpas_layout(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}) attributes {{noinline = false}} {{
4043
%0 = tt.get_program_id x : i32
4144
%M_i64 = arith.constant {M} : i64

python/test/unit/language/test_core.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6704,13 +6704,17 @@ def test_local_load_store_dot(M, N, dtype, dist_layout, shared_layout, device, t
67046704
elif dtype == "float8e5":
67056705
mlir_dtype = "f8E5M2"
67066706

6707+
num_warps = 4
6708+
if isinstance(dist_layout, DotOperandLayout) and isinstance(dist_layout.parent, DpasLayout):
6709+
num_warps = math.prod(dist_layout.parent.warps_per_cta)
6710+
67076711
layouts = f"""
67086712
#dist = {dist_layout}
67096713
#shared = {shared_layout}
67106714
#smem = #ttg.shared_memory
67116715
"""
67126716
ir = layouts + f"""
6713-
module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}} {{
6717+
module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, "ttg.threads-per-warp" = 32 : i32}} {{
67146718
tt.func public @kernel(%arg0: !tt.ptr<{mlir_dtype}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{mlir_dtype}> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
67156719
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #dist>
67166720
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>

test/Conversion/intel/dot_layout_offset.mlir

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
#dpas = #ttig.dpas<{repeatCount=8, systolicDepth=8, executionSize = 8, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA=[1, 1], repCluster=[2, 2]}>
44
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
5-
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {
5+
module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} {
66
// CHECK-LABEL: llvm.func spir_kernelcc @dot_layout_emit_offset(%arg0: !llvm.ptr<1>)
77
tt.func public @dot_layout_emit_offset() {
88
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #dot_operand_a>
@@ -11,12 +11,10 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32}
1111
// COM: Base index of the dot layout.
1212
// CHECK: %[[THREAD_ID_I64:.*]] = llvm.call spir_funccc @_Z12get_local_idj
1313
// CHECK: %[[THREAD_ID_I32:.*]] = llvm.trunc %[[THREAD_ID_I64]] : i64 to i32
14-
// CHECK: %[[CST_63:.*]] = llvm.mlir.constant(63 : i32) : i32
15-
// CHECK: %[[RTID:.*]] = llvm.and %[[THREAD_ID_I32]], %[[CST_63]] : i32
14+
// CHECK: %[[CST_63:.*]] = llvm.mlir.constant(15 : i32) : i32
15+
// CHECK: %[[LANE_ID:.*]] = llvm.and %[[THREAD_ID_I32]], %[[CST_63]] : i32
1616
// CHECK: %[[VAL_145:.*]] = llvm.mlir.constant(16 : i32) : i32
17-
// CHECK: %[[LANE_ID:.*]] = llvm.urem %[[RTID]], %[[VAL_145]] : i32
18-
// CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[RTID]], %[[VAL_145]] : i32
19-
// CHECK-COUNT-4: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
17+
// CHECK-COUNT-5: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
2018
// CHECK: %[[VAL_149:.*]] = llvm.mlir.constant(1 : i32) : i32
2119
// CHECK: %[[VAL_150:.*]] = llvm.and %[[LANE_ID]], %[[VAL_149]] : i32
2220
// CHECK: %[[VAL_151:.*]] = llvm.icmp "eq" %[[VAL_150]], %[[CST_0]] : i32
@@ -324,7 +322,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32}
324322

325323
#dpas = #ttig.dpas<{repeatCount=8, systolicDepth=8, executionSize = 8, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA=[1, 1], repCluster=[2, 2]}>
326324
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
327-
module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} {
325+
module attributes {"ttg.num-warps" = 1 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} {
328326

329327
// CHECK-LABEL: llvm.func spir_kernelcc @dot_layout_emit_offset(%arg0: !llvm.ptr<1>)
330328
tt.func public @dot_layout_emit_offset() {
@@ -335,12 +333,10 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.thr
335333
// COM: Base index of the dot layout.
336334
// CHECK: %[[THREAD_ID_I64:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAL_142]])
337335
// CHECK: %[[THREAD_ID_I32:.*]] = llvm.trunc %[[THREAD_ID_I64]] : i64 to i32
338-
// CHECK-DAG: %[[CST_63:.*]] = llvm.mlir.constant(63 : i32) : i32
339-
// CHECK-DAG: %[[RTID:.*]] = llvm.and %[[THREAD_ID_32:.*]], %[[CST_63]] : i32
336+
// CHECK-DAG: %[[CST_63:.*]] = llvm.mlir.constant(15 : i32) : i32
337+
// CHECK-DAG: %[[LANE_ID:.*]] = llvm.and %[[THREAD_ID_32:.*]], %[[CST_63]] : i32
340338
// CHECK-DAG: %[[VAL_145:.*]] = llvm.mlir.constant(16 : i32) : i32
341-
// CHECK-DAG: %[[LANE_ID:.*]] = llvm.urem %[[RTID]], %[[VAL_145]] : i32
342-
// CHECK-DAG: %[[WARP_ID:.*]] = llvm.udiv %[[RTID]], %[[VAL_145]] : i32
343-
// CHECK-COUNT-4: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
339+
// CHECK-COUNT-5: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
344340
// CHECK: %[[VAL_149:.*]] = llvm.mlir.constant(1 : i32) : i32
345341
// CHECK: %[[VAL_150:.*]] = llvm.and %[[LANE_ID]], %[[VAL_149]] : i32
346342
// CHECK: %[[VAL_151:.*]] = llvm.icmp "eq" %[[VAL_150]], %[[CST_0]] : i32

test/TritonIntelGPU/materialize-block-pointer.mlir

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
44
#dot_a = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>
55
#dot_b = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
6-
module attributes {"ttg.num-ctas" = 1 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
6+
module attributes {"ttg.num-ctas" = 1 : i32, ttg.target = "xpu", "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
77
// CHECK-LABEL: tt.func public @materialize_block_pointer(
88
tt.func public @materialize_block_pointer(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 15 : i32}, %pitch: i64 {tt.divisibility = 16 : i32}, %pitch_odd: i64 {tt.divisibility = 15 : i32}) {
99
%c0_i32 = arith.constant 0 : i32
@@ -192,7 +192,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
192192

193193
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
194194
#dot_a = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>
195-
module attributes {"ttg.num-ctas" = 1 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
195+
module attributes {"ttg.num-ctas" = 1 : i32, ttg.target = "xpu", "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
196196
// CHECK-LABEL: tt.func public @materialize_block_pointer(
197197
tt.func public @materialize_block_pointer(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %pitch: i64 {tt.divisibility = 16 : i32}) {
198198
%c0_i32 = arith.constant 0 : i32

test/TritonIntelGPU/tensor-pointer-store-block-2d.mlir

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
5454

5555
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [4, 4], repCluster = [2, 2]}>
5656
#dot_a = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>
57-
module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
57+
module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 16 : i32} {
5858
// CHECK-LABEL: @regular_pointer_block_io
5959
tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<i8>) {
6060
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #dot_a}>>
@@ -69,7 +69,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
6969
%9 = tt.splat %arg0 : !tt.ptr<i8> -> tensor<256x64x!tt.ptr<i8>, #dot_a>
7070
%addr = tt.addptr %9, %8 : tensor<256x64x!tt.ptr<i8>, #dot_a>, tensor<256x64xi32, #dot_a>
7171
%cst = arith.constant dense<0> : tensor<256x64xi8, #dot_a>
72-
// CHECK-COUNT-32: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
72+
// CHECK-COUNT-16: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
7373
tt.store %addr, %cst {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<i8>, #dot_a>
7474

7575
tt.return
@@ -80,7 +80,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
8080

8181
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 4], repCluster = [2, 2]}>
8282
#dot_a = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>
83-
module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
83+
module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 16 : i32} {
8484
// CHECK-LABEL: @regular_pointer_block_io
8585
tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<f32>) {
8686
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #dot_a}>>
@@ -95,7 +95,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
9595
%9 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x64x!tt.ptr<f32>, #dot_a>
9696
%addr = tt.addptr %9, %8 : tensor<256x64x!tt.ptr<f32>, #dot_a>, tensor<256x64xi32, #dot_a>
9797
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #dot_a>
98-
// CHECK-COUNT-128: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default}
98+
// CHECK-COUNT-64: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default}
9999
tt.store %addr, %cst {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f32>, #dot_a>
100100

101101
tt.return
@@ -106,7 +106,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
106106

107107
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 4], repCluster = [2, 2]}>
108108
#dot_b = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 1}>
109-
module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
109+
module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 16 : i32} {
110110
// CHECK-LABEL: @regular_pointer_block_io
111111
tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<f32>) {
112112
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #dot_b}>>
@@ -121,7 +121,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
121121
%9 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x64x!tt.ptr<f32>, #dot_b>
122122
%addr = tt.addptr %9, %8 : tensor<256x64x!tt.ptr<f32>, #dot_b>, tensor<256x64xi32, #dot_b>
123123
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #dot_b>
124-
// CHECK-COUNT-128: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
124+
// CHECK-COUNT-64: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
125125
tt.store %addr, %cst {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f32>, #dot_b>
126126

127127
tt.return
@@ -131,7 +131,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
131131
// -----
132132

133133
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 4], repCluster = [2, 2]}>
134-
module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
134+
module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 16 : i32} {
135135
// CHECK-LABEL: @regular_pointer_block_io
136136
tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<f32>) {
137137
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #dpas}>>
@@ -146,7 +146,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
146146
%9 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x64x!tt.ptr<f32>, #dpas>
147147
%addr = tt.addptr %9, %8 : tensor<256x64x!tt.ptr<f32>, #dpas>, tensor<256x64xi32, #dpas>
148148
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #dpas>
149-
// CHECK-COUNT-32: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
149+
// CHECK-COUNT-16: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
150150
tt.store %addr, %cst {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f32>, #dpas>
151151

152152
tt.return

0 commit comments

Comments
 (0)