
Commit 244c285

chengjunlu and etiotto authored
[BACKEND] Enhance the 2D Block io lowering for tt.store. (#4561)
Refactor the 2D block IO store lowering:

1. Use a single 2D block IO store lowering pattern, registered with higher priority, for `tt.store` with either a regular pointer or a block pointer.
2. Use the linear layout utilities to compute the block IO tile shape, so that layouts beyond the DPAS layout are supported.
3. Fix the flakiness caused by not checking memory contiguity of the store pointer.
4. Add boundary protection support for block pointers.

---------

Signed-off-by: Lu,Chengjun <[email protected]>
Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Tiotto, Ettore <[email protected]>
1 parent dcc6db9 · commit 244c285
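To make the user-facing pattern behind items 1 and 4 concrete, here is a minimal sketch (not part of this diff; the kernel name, tensor sizes, and the `xpu` device are illustrative assumptions): a 2D tile is copied through block pointers, and `boundary_check` on the store requests the out-of-bounds protection that this change adds to the block-pointer store lowering.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_2d_kernel(in_ptr, out_ptr, M, N, stride_m, stride_n,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Row-major block pointers over the same logical (M, N) tensor.
    in_blk = tl.make_block_ptr(base=in_ptr, shape=(M, N), strides=(stride_m, stride_n),
                               offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                               block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    out_blk = tl.make_block_ptr(base=out_ptr, shape=(M, N), strides=(stride_m, stride_n),
                                offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tile = tl.load(in_blk, boundary_check=(0, 1), padding_option="zero")
    # boundary_check on the store is the case the block-pointer boundary protection targets.
    tl.store(out_blk, tile, boundary_check=(0, 1))


# Hypothetical launch: sizes deliberately not multiples of the tile so the boundary path is exercised.
a = torch.randn(300, 257, device="xpu")
b = torch.empty_like(a)
grid = (triton.cdiv(300, 64), triton.cdiv(257, 64))
copy_2d_kernel[grid](a, b, 300, 257, a.stride(0), a.stride(1), BLOCK_M=64, BLOCK_N=64)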

File tree

4 files changed: +783 −406 lines
python/test/unit/intel/test_block_store.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
import itertools
import pathlib

import numpy as np
import pytest
import torch

import triton
from triton._internal_testing import is_xpu


class DpasLayout:

    def __init__(self, repeatCount, systolic_depth, execution_size, ops_per_chan, threads_per_warp, warps_per_cta,
                 rep_cluster):
        self.repeatCount = repeatCount
        self.systolic_depth = systolic_depth
        self.execution_size = execution_size
        self.ops_per_chan = ops_per_chan
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.rep_cluster = rep_cluster

    def __str__(self):
        return f"#ttig.dpas<{{repeatCount={self.repeatCount}, systolicDepth={self.systolic_depth}, executionSize = {self.execution_size}, opsPerChan = {self.ops_per_chan}, threadsPerWarp = {self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, repCluster={self.rep_cluster}}}>"


def warps_per_cta(layout):
    return layout.warps_per_cta


# DPAS layouts for Xe covering ops_per_chan = 4, 2 and 1.
layouts = [
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=4, threads_per_warp=16,
               warps_per_cta=[1, 4], rep_cluster=[1, 2]),
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
               warps_per_cta=[8, 4], rep_cluster=[4, 2]),
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=16,
               warps_per_cta=[8, 4], rep_cluster=[1, 1]),
]


@pytest.mark.parametrize("M, N", [[M, N] for M, N in itertools.product([32, 64, 128, 256], [32, 64, 128, 256])])
@pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
@pytest.mark.parametrize("layout", layouts)
@pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
def test_tensor_pointer_block_store(M, N, dtype_str, layout, device, tmp_path: pathlib.Path):

    warps = warps_per_cta(layout)
    num_warps = int(np.prod(warps))
    threads_per_warp = layout.threads_per_warp
    ops_per_chan = layout.ops_per_chan
    # kWidth of the dot-operand layouts derived from the parent DPAS layout.
    A_width = 1 if ops_per_chan == 1 else ops_per_chan // 2
    B_width = ops_per_chan

    ty = {"float32": "f32", "float16": "f16", "bfloat16": "i16", "int8": "i8"}[dtype_str]

    # Only advertise 2D block IO support when the device reports it.
    support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']

    # TTGIR that loads an A tile and a B tile through tensor-of-pointer accesses and stores them
    # back through tt.store ops annotated with ttig.block_io = "row_major".
    ir = f"""
    #mma = {layout}
    #dot_a = #ttg.dot_op<{{opIdx = 0, parent = #mma, kWidth = {A_width}}}>
    #dot_b = #ttg.dot_op<{{opIdx = 1, parent = #mma, kWidth = {B_width}}}>
    module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
    tt.func public @tensor_pointer_block_load(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{

        // A matrix
        %stride_a = arith.constant dense<{N}> : tensor<{M}x1xi32, #dot_a>
        %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>>
        %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>> -> tensor<{M}x1xi32, #dot_a>
        %4 = arith.muli %2, %stride_a : tensor<{M}x1xi32, #dot_a>
        %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>>
        %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>> -> tensor<1x{N}xi32, #dot_a>
        %7 = tt.broadcast %4 : tensor<{M}x1xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
        %8 = tt.broadcast %6 : tensor<1x{N}xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
        %9 = arith.addi %7, %8 : tensor<{M}x{N}xi32, #dot_a>

        %10 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
        %11 = tt.addptr %10, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
        %12 = tt.load %11 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
        %13 = tt.splat %arg1 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
        %14 = tt.addptr %13, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
        tt.store %14, %12 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>

        // B matrix
        %stride_b = arith.constant dense<{N}> : tensor<{M}x1xi32, #dot_b>
        %22 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>>
        %44 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>>
        %46 = tt.expand_dims %44 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>> -> tensor<{M}x1xi32, #dot_b>
        %49 = arith.muli %46, %stride_b : tensor<{M}x1xi32, #dot_b>
        %50 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>> -> tensor<1x{N}xi32, #dot_b>
        %51 = tt.broadcast %49 : tensor<{M}x1xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
        %52 = tt.broadcast %50 : tensor<1x{N}xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
        %53 = arith.addi %51, %52 : tensor<{M}x{N}xi32, #dot_b>

        %54 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
        %55 = tt.addptr %54, %53 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
        %56 = tt.load %55 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
        %57 = tt.splat %arg3 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
        %58 = tt.addptr %57, %53 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
        tt.store %58, %56 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>

        tt.return
    }}
    }}
    """

    torch_dtype = getattr(torch, dtype_str)
    if torch_dtype.is_floating_point:
        a = torch.randn((M, N), dtype=torch_dtype, device=device)
    else:
        a = torch.randint(low=-127, high=128, size=(M, N), dtype=torch_dtype, device=device)

    x = torch.empty_like(a)
    y = torch.empty_like(a)

    # Compile the TTGIR directly and check that both stores wrote the loaded data back verbatim.
    temp_file = tmp_path / "test_tensor_pointer_block_store.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))

    kernel[(1, 1, 1)](a, x, a, y)
    assert torch.equal(a, x) and torch.equal(a, y)

    temp_file.unlink()
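As a reading aid (not part of the diff), this is the `#mma` attribute string that the first entry in `layouts` expands to via `DpasLayout.__str__` when substituted into the IR above:

>>> print(layouts[0])
#ttig.dpas<{repeatCount=8, systolicDepth=8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA=[1, 4], repCluster=[1, 2]}>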

scripts/skiplist/lts/intel.txt

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 python/test/unit/intel/test_block_load.py::test_block_load_dpas_layout
+python/test/unit/intel/test_block_store.py::test_tensor_pointer_block_store
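For convenience (not part of the diff), the new test can be selected by the same pytest node ID that the skiplist entry above uses; a minimal sketch, assuming it is run from the repository root on an XPU machine not covered by the LTS skiplist:

import pytest

# Run only the newly added block-store test by its node ID.
raise SystemExit(pytest.main(["python/test/unit/intel/test_block_store.py::test_tensor_pointer_block_store"]))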
