
Commit 1f0a363

Post review comment for #4561 (#4634)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent a14e141 commit 1f0a363

File tree

1 file changed (+30 -32 lines)

python/test/unit/intel/test_block_store.py

Lines changed: 30 additions & 32 deletions
@@ -43,7 +43,7 @@ def warps_per_cta(layout):
 @pytest.mark.parametrize("M, N", [[M, N] for M, N in itertools.product([32, 64, 128, 256], [32, 64, 128, 256])])
 @pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
 @pytest.mark.parametrize("layout", layouts)
-@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
+@pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
 def test_tensor_pointer_block_store(M, N, dtype_str, layout, device, tmp_path: pathlib.Path):
 
     warps = warps_per_cta(layout)
@@ -62,43 +62,43 @@ def test_tensor_pointer_block_store(M, N, dtype_str, layout, device, tmp_path: p
 #dot_a = #ttg.dot_op<{{opIdx = 0, parent = #mma, kWidth = {A_width}}}>
 #dot_b = #ttg.dot_op<{{opIdx = 1, parent = #mma, kWidth = {B_width}}}>
 module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
-tt.func public @tensor_pointer_block_load(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}) {{
+tt.func public @tensor_pointer_block_store(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}) {{
 
 // A matrix
 %stride_a = arith.constant dense<{N}> : tensor<{M}x1xi32, #dot_a>
 %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>>
 %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>> -> tensor<{M}x1xi32, #dot_a>
-%4 = arith.muli %2, %stride_a : tensor<{M}x1xi32, #dot_a>
-%5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>>
-%6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>> -> tensor<1x{N}xi32, #dot_a>
-%7 = tt.broadcast %4 : tensor<{M}x1xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
-%8 = tt.broadcast %6 : tensor<1x{N}xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
-%9 = arith.addi %7, %8 : tensor<{M}x{N}xi32, #dot_a>
-
-%10 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-%11 = tt.addptr %10, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
-%12 = tt.load %11 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-%13 = tt.splat %arg1 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-%14 = tt.addptr %13, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
-tt.store %14, %12 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
+%3 = arith.muli %2, %stride_a : tensor<{M}x1xi32, #dot_a>
+%4 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>>
+%5 = tt.expand_dims %4 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>> -> tensor<1x{N}xi32, #dot_a>
+%6 = tt.broadcast %3 : tensor<{M}x1xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
+%7 = tt.broadcast %5 : tensor<1x{N}xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
+%8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #dot_a>
+
+%9 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
+%10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
+%11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
+%12 = tt.splat %arg1 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
+%13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
+tt.store %13, %11 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
 
 // B matrix
 %stride_b = arith.constant dense<{N}> : tensor<{M}x1xi32, #dot_b>
-%22 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>>
-%44 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>>
-%46 = tt.expand_dims %44 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>> -> tensor<{M}x1xi32, #dot_b>
-%49 = arith.muli %46, %stride_b : tensor<{M}x1xi32, #dot_b>
-%50 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>> -> tensor<1x{N}xi32, #dot_b>
-%51 = tt.broadcast %49 : tensor<{M}x1xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
-%52 = tt.broadcast %50 : tensor<1x{N}xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
-%53 = arith.addi %51, %52 : tensor<{M}x{N}xi32, #dot_b>
-
-%54 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
-%55 = tt.addptr %54, %53 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
-%56 = tt.load %55 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
-%57 = tt.splat %arg3 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
-%58 = tt.addptr %57, %53 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
-tt.store %58, %56 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
+%21 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>>
+%22 = tt.expand_dims %21 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>> -> tensor<{M}x1xi32, #dot_b>
+%23 = arith.muli %22, %stride_b : tensor<{M}x1xi32, #dot_b>
+%24 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>>
+%25 = tt.expand_dims %24 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>> -> tensor<1x{N}xi32, #dot_b>
+%26 = tt.broadcast %23 : tensor<{M}x1xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
+%27 = tt.broadcast %25 : tensor<1x{N}xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
+%28 = arith.addi %26, %27 : tensor<{M}x{N}xi32, #dot_b>
+
+%29 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
+%30 = tt.addptr %29, %28 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
+%31 = tt.load %30 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
+%32 = tt.splat %arg3 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
+%33 = tt.addptr %32, %28 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
+tt.store %33, %31 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
 
 tt.return
 }}
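
Note: the TTGIR template above encodes a plain row-major block copy. For each matrix it builds offsets = row * stride + col out of tt.make_range, tt.expand_dims, tt.broadcast, and arith ops, splats a base pointer, loads a tile, and stores it through a second pointer tagged ttig.block_io = "row_major". For orientation, here is a minimal Python/Triton sketch of the same access pattern (illustrative only, not part of this diff; the kernel name is hypothetical):

import triton
import triton.language as tl

@triton.jit
def block_copy_kernel(src_ptr, dst_ptr, M: tl.constexpr, N: tl.constexpr):
    offs_m = tl.arange(0, M)                       # tt.make_range over rows
    offs_n = tl.arange(0, N)                       # tt.make_range over columns
    # expand_dims + broadcast + muli/addi: offsets = row * N + col
    offsets = offs_m[:, None] * N + offs_n[None, :]
    vals = tl.load(src_ptr + offsets)              # tt.splat + tt.addptr + tt.load
    tl.store(dst_ptr + offsets, vals)              # tt.store of the whole block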
@@ -120,5 +120,3 @@ def test_tensor_pointer_block_store(M, N, dtype_str, layout, device, tmp_path: p
 
     kernel[(1, 1, 1)](a, x, a, y)
     assert torch.equal(a, x) and torch.equal(a, y)
-
-    temp_file.unlink()
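
Note: the deleted temp_file.unlink() was redundant cleanup. The file lives under the pytest tmp_path fixture, and pytest manages those per-test directories itself (stale base temp directories are pruned automatically), so no manual deletion is needed. A minimal sketch of the assumed pattern, with illustrative names:

import pathlib

def test_example(tmp_path: pathlib.Path):
    temp_file = tmp_path / "kernel.ttgir"   # unique per-test directory
    temp_file.write_text("...")
    # no temp_file.unlink() needed: pytest prunes old tmp_path trees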
