Commit 2823aaa

[TEST] Add more block store test cases (#4717)
1 parent 51244af commit 2823aaa

File tree

2 files changed: 116 additions & 52 deletions


python/test/unit/intel/test_block_store.py

Lines changed: 115 additions & 51 deletions
@@ -25,80 +25,145 @@ def __str__(self):
         return f"#ttig.dpas<{{repeatCount={self.repeatCount}, systolicDepth={self.systolic_depth}, executionSize = {self.execution_size}, opsPerChan = {self.ops_per_chan}, threadsPerWarp = {self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, repCluster={self.rep_cluster}}}>"
 
 
+class DotOperandLayout:
+
+    def __init__(self, parent, op_idx, k_width):
+        self.parent = parent
+        self.op_idx = op_idx
+        self.k_width = k_width
+        self.threads_per_warp = parent.threads_per_warp
+
+    def __str__(self):
+        return f"#ttg.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>"
+
+
+class SliceLayout:
+
+    def __init__(self, dim, parent):
+        self.dim = dim
+        self.parent = parent
+        self.threads_per_warp = parent.threads_per_warp
+
+    def __str__(self):
+        return f"#ttg.slice<{{dim = {self.dim}, parent = {self.parent}}}>"
+
+
+class BlockedLayout:
+
+    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1],
+                 cta_split_num=[1, 1], cta_order=[0, 1]):
+        self.sz_per_thread = size_per_thread
+        self.threads_per_warp = threads_per_warp
+        self.warps_per_cta = warps_per_cta
+        self.order = order
+        self.ctas_per_cga = ctas_per_cga
+        self.cta_split_num = cta_split_num
+        self.cta_order = cta_order
+
+    def __str__(self):
+        return f"#ttg.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
+
+
 def warps_per_cta(layout):
-    return layout.warps_per_cta
+    if isinstance(layout, (SliceLayout, DotOperandLayout)):
+        return warps_per_cta(layout.parent)
+    else:
+        return layout.warps_per_cta
 
 
 layouts = [
-    # Layout for Xe
+    BlockedLayout([1, 1], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
+    # DPAS layout
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=4, threads_per_warp=16,
                warps_per_cta=[1, 4], rep_cluster=[1, 2]),
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
                warps_per_cta=[8, 4], rep_cluster=[4, 2]),
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=16,
                warps_per_cta=[8, 4], rep_cluster=[1, 1]),
+    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
+               warps_per_cta=[4, 1], rep_cluster=[1, 1]),
+    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32,
+               warps_per_cta=[2, 2], rep_cluster=[1, 1]),
+    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=4, threads_per_warp=32,
+               warps_per_cta=[4, 1], rep_cluster=[1, 1]),
+    # DotOp A
+    DotOperandLayout(
+        parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32,
+                          warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=0, k_width=1),
+    DotOperandLayout(
+        parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=16,
+                          warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=0, k_width=1),
+    # DotOp B
+    DotOperandLayout(
+        parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=16,
+                          warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=1, k_width=1),
+    DotOperandLayout(
+        parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
+                          warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=1, k_width=2),
+    DotOperandLayout(
+        parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=4, threads_per_warp=16,
+                          warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=1, k_width=4),
+    DotOperandLayout(
+        parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=32,
+                          warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=1, k_width=1),
+    # Slice layout
+    SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [2, 1, 16], [2, 1, 2], [2, 1, 0], [1, 1, 1], [1, 1, 1],
+                                            [0, 1, 2])),
 ]
 
 
-@pytest.mark.parametrize("M, N", [[M, N] for M, N in itertools.product([32, 64, 128, 256], [32, 64, 128, 256])])
+@pytest.mark.parametrize("M, N", [[M, N] for M, N in itertools.product([32, 64, 128], [64, 128])])
 @pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
 @pytest.mark.parametrize("layout", layouts)
+@pytest.mark.parametrize("block_ptr", [True, False])
 @pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
-def test_tensor_pointer_block_store(M, N, dtype_str, layout, device, tmp_path: pathlib.Path):
+def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathlib.Path):
 
     warps = warps_per_cta(layout)
     num_warps = int(np.prod(warps))
     threads_per_warp = layout.threads_per_warp
-    ops_per_chan = layout.ops_per_chan
-    A_width = 1 if ops_per_chan == 1 else ops_per_chan // 2
-    B_width = ops_per_chan
+    threads_per_warp = int(np.prod(threads_per_warp))
 
     ty = {"float32": "f32", "float16": "f16", "bfloat16": "i16", "int8": "i8"}[dtype_str]
 
     support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
 
+    if block_ptr:
+        store_ops = f"""
+        %M_i64 = arith.constant {M} : i64
+        %N_i64 = arith.constant {N} : i64
+        %c1_i64 = arith.constant 1 : i64
+        %c0_i32 = arith.constant 0 : i32
+
+        %blk_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+        tt.store %blk_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+        """
+    else:
+        store_ops = f"""
+        %12 = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+        %13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+        tt.store %13, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+        """
+
     ir = f"""
-    #mma = {layout}
-    #dot_a = #ttg.dot_op<{{opIdx = 0, parent = #mma, kWidth = {A_width}}}>
-    #dot_b = #ttg.dot_op<{{opIdx = 1, parent = #mma, kWidth = {B_width}}}>
+    #layout = {layout}
     module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
-    tt.func public @tensor_pointer_block_store(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}) {{
-
-        // A matrix
-        %stride_a = arith.constant dense<{N}> : tensor<{M}x1xi32, #dot_a>
-        %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>>
-        %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>> -> tensor<{M}x1xi32, #dot_a>
-        %3 = arith.muli %2, %stride_a : tensor<{M}x1xi32, #dot_a>
-        %4 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>>
-        %5 = tt.expand_dims %4 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>> -> tensor<1x{N}xi32, #dot_a>
-        %6 = tt.broadcast %3 : tensor<{M}x1xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
-        %7 = tt.broadcast %5 : tensor<1x{N}xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
-        %8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #dot_a>
-
-        %9 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-        %10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
-        %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-        %12 = tt.splat %arg1 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-        %13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
-        tt.store %13, %11 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-
-        // B matrix
-        %stride_b = arith.constant dense<{N}> : tensor<{M}x1xi32, #dot_b>
-        %21 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>>
-        %22 = tt.expand_dims %21 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>> -> tensor<{M}x1xi32, #dot_b>
-        %23 = arith.muli %22, %stride_b : tensor<{M}x1xi32, #dot_b>
-        %24 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>>
-        %25 = tt.expand_dims %24 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>> -> tensor<1x{N}xi32, #dot_b>
-        %26 = tt.broadcast %23 : tensor<{M}x1xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
-        %27 = tt.broadcast %25 : tensor<1x{N}xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
-        %28 = arith.addi %26, %27 : tensor<{M}x{N}xi32, #dot_b>
-
-        %29 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
-        %30 = tt.addptr %29, %28 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
-        %31 = tt.load %30 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
-        %32 = tt.splat %arg3 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
-        %33 = tt.addptr %32, %28 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
-        tt.store %33, %31 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
+    tt.func public @block_store(%src: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %dst: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
+
+        %stride = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
+        %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
+        %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{M}x1xi32, #layout>
+        %3 = arith.muli %2, %stride : tensor<{M}x1xi32, #layout>
+        %4 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #layout}}>>
+        %5 = tt.expand_dims %4 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #layout}}>> -> tensor<1x{N}xi32, #layout>
+        %6 = tt.broadcast %3 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+        %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+        %8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
+        %9 = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+        %10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+        %store_val = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+
+        {store_ops}
 
         tt.return
     }}
@@ -112,11 +177,10 @@ def test_tensor_pointer_block_store(M, N, dtype_str, layout, device, tmp_path: p
     a = torch.randint(low=-127, high=128, size=(M, N), dtype=torch_dtype, device=device)
 
     x = torch.empty_like(a)
-    y = torch.empty_like(a)
 
-    temp_file = tmp_path / "test_tensor_pointer_block_store.ttgir"
+    temp_file = tmp_path / "test_block_store.ttgir"
    temp_file.write_text(ir)
     kernel = triton.compile(str(temp_file))
 
-    kernel[(1, 1, 1)](a, x, a, y)
-    assert torch.equal(a, x) and torch.equal(a, y)
+    kernel[(1, 1, 1)](a, x)
+    assert torch.equal(a, x)
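
For reference, a minimal sketch of how the new layout helpers compose (it assumes the classes and the warps_per_cta helper from python/test/unit/intel/test_block_store.py above are in scope; the example values are taken from the layouts list in this diff):

    # Build one of the DotOperandLayout entries from the layouts list above.
    dpas = DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2,
                      threads_per_warp=16, warps_per_cta=[2, 2], rep_cluster=[1, 1])
    dot_b = DotOperandLayout(parent=dpas, op_idx=1, k_width=2)

    # str(dot_b) is the attribute string substituted for #layout in the ir template,
    # i.e. #ttg.dot_op<{parent=#ttig.dpas<{...}>, opIdx=1, kWidth=2}>.
    print(dot_b)

    # warps_per_cta recurses through wrapper layouts to the parent DPAS layout,
    # so the test derives num_warps = int(np.prod([2, 2])) == 4 for this case.
    print(warps_per_cta(dot_b))  # [2, 2]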

scripts/skiplist/lts/intel.txt

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 python/test/unit/intel/test_block_load.py::test_block_load_dpas_layout
-python/test/unit/intel/test_block_store.py::test_tensor_pointer_block_store
+python/test/unit/intel/test_block_store.py::test_block_store
