
Commit ba1a5b7

Separated the optimal DPAS layout calculation logic, exposed it to Python, and enabled all GEMM benchmark cases for Gluon
1 parent 2aa5ff7 commit ba1a5b7


8 files changed, +438 -195 lines changed


benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 39 additions & 65 deletions
@@ -21,6 +21,8 @@
 from triton_kernels_benchmark import xetla_kernel
 from triton_kernels_benchmark import cutlass_kernel
 
+from utils.dpas_layout_analyzer import calculate_optimal_warps_per_cta, calculate_optimal_rep_clusters
+
 
 def get_matmul_autotune_configs() -> List[triton.Config]:
     configs = [
@@ -178,32 +180,20 @@ def get_gluon_matmul_autotune_configs(base_configs_fn: Callable) -> List[triton.
             # Append additional meta parameters needed for gluon kernel
             # To determine prefetch distance and DPAS layout
             {**config.kwargs, 'NUM_STAGES': config.num_stages, 'NUM_WARPS': config.num_warps},
-            num_stages=config.num_stages,
-            num_warps=config.num_warps
-        )
-        for config in base_configs
+            num_stages=config.num_stages, num_warps=config.num_warps) for config in base_configs
     ]
 
 
 @gluon.constexpr_function
-def get_dpas_layout(num_warps: ttgl.constexpr) -> ttgl.constexpr:
-    # TODO: return same DPAS layout as calculated by passes for triton
-    warps_per_cta = [2, 2]
-    if num_warps == 16:
-        warps_per_cta = [4, 4]
-    if num_warps == 32:
-        warps_per_cta = [4, 8]
-    elif num_warps == 64:
-        warps_per_cta = [8, 8]
+def get_dpas_layout(num_warps: ttgl.constexpr, m_shape: ttgl.constexpr, n_shape: ttgl.constexpr,
+                    k_shape: ttgl.constexpr) -> ttgl.constexpr:
+    threads_per_warp = 16
+    warps_per_cta = calculate_optimal_warps_per_cta(num_warps, m_shape, n_shape)
+
     return IntelDPASLayout(
-        repeatCount=8,
-        systolic_depth=8,
-        execution_size=16,
-        ops_per_chan=2,
-        warps_per_cta=warps_per_cta,
-        rep_cluster=[4, 2],
-        threads_per_warp=16
-    )
+        repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, warps_per_cta=warps_per_cta,
+        rep_cluster=calculate_optimal_rep_clusters(m_shape, n_shape, k_shape, threads_per_warp,
+                                                   warps_per_cta), threads_per_warp=threads_per_warp)
 
 
 @triton.autotune(
@@ -217,16 +207,14 @@ def gluon_matmul_kernel_dpas_tensor_desc(
         # Matrix dimensions
         M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr,
         # Stride variables
-        stride_am: ttgl.constexpr, stride_ak: ttgl.constexpr,
-        stride_bk: ttgl.constexpr, stride_bn: ttgl.constexpr,
+        stride_am: ttgl.constexpr, stride_ak: ttgl.constexpr, stride_bk: ttgl.constexpr, stride_bn: ttgl.constexpr,
         stride_cm: ttgl.constexpr, stride_cn: ttgl.constexpr,
         # Meta parameters
         BLOCK_SIZE_M: ttgl.constexpr, BLOCK_SIZE_N: ttgl.constexpr, BLOCK_SIZE_K: ttgl.constexpr,
         GROUP_SIZE_M: ttgl.constexpr,
         # Gluon meta parameters
         NUM_STAGES: ttgl.constexpr, NUM_WARPS: ttgl.constexpr):
-    layout: ttgl.constexpr = get_dpas_layout(NUM_WARPS)
-
+    layout: ttgl.constexpr = get_dpas_layout(NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)
 
     lhs_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=layout, operand_index=0, k_width=1)
     rhs_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=layout, operand_index=1, k_width=2)
@@ -241,19 +229,19 @@ def gluon_matmul_kernel_dpas_tensor_desc(
     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
     pid_n = (pid % num_pid_in_group) // group_size_m
 
-    a_desc = ttgl.intel.xpu.xe.make_tensor_descriptor(a_ptr, (M, K), (stride_am, stride_ak), (BLOCK_SIZE_M, BLOCK_SIZE_K),
-                                                      lhs_layout)
-    b_desc = ttgl.intel.xpu.xe.make_tensor_descriptor(b_ptr, (K, N), (stride_bk, stride_bn), (BLOCK_SIZE_K, BLOCK_SIZE_N),
-                                                      rhs_layout)
-    c_desc = ttgl.intel.xpu.xe.make_tensor_descriptor(c_ptr, (M, N), (stride_cm, stride_cn), (BLOCK_SIZE_M, BLOCK_SIZE_N), layout)
+    a_desc = ttgl.intel.xpu.xe.make_tensor_descriptor(a_ptr, (M, K), (stride_am, stride_ak),
+                                                      (BLOCK_SIZE_M, BLOCK_SIZE_K), lhs_layout)
+    b_desc = ttgl.intel.xpu.xe.make_tensor_descriptor(b_ptr, (K, N), (stride_bk, stride_bn),
+                                                      (BLOCK_SIZE_K, BLOCK_SIZE_N), rhs_layout)
+    c_desc = ttgl.intel.xpu.xe.make_tensor_descriptor(c_ptr, (M, N), (stride_cm, stride_cn),
+                                                      (BLOCK_SIZE_M, BLOCK_SIZE_N), layout)
 
     # Clear accumulator
     zero_tensor = ttgl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ttgl.float32, layout=layout)
     c_desc.store_2d([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], zero_tensor)
 
     accumulator = c_desc.load_2d([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N])
 
-
     # Prefetch first blocks for A and B matrices (pre-loop prefetches)
     for i in range(NUM_STAGES):
         if i * BLOCK_SIZE_K < K:
@@ -286,15 +274,15 @@ def gluon_matmul_kernel_dpas_tensor_desc_batched(
         # Matrix dimensions
         B: ttgl.constexpr, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr,
         # Stride variables
-        stride_az: ttgl.constexpr, stride_am: ttgl.constexpr, stride_ak: ttgl.constexpr,
-        stride_bz: ttgl.constexpr, stride_bk: ttgl.constexpr, stride_bn: ttgl.constexpr,
-        stride_cz: ttgl.constexpr, stride_cm: ttgl.constexpr, stride_cn: ttgl.constexpr,
+        stride_az: ttgl.constexpr, stride_am: ttgl.constexpr, stride_ak: ttgl.constexpr, stride_bz: ttgl.constexpr,
+        stride_bk: ttgl.constexpr, stride_bn: ttgl.constexpr, stride_cz: ttgl.constexpr, stride_cm: ttgl.constexpr,
+        stride_cn: ttgl.constexpr,
         # Meta parameters
         BLOCK_SIZE_M: ttgl.constexpr, BLOCK_SIZE_N: ttgl.constexpr, BLOCK_SIZE_K: ttgl.constexpr,
         GROUP_SIZE_M: ttgl.constexpr,
         # Gluon meta parameters
         NUM_STAGES: ttgl.constexpr, NUM_WARPS: ttgl.constexpr):
-    layout: ttgl.constexpr = get_dpas_layout(NUM_WARPS)
+    layout: ttgl.constexpr = get_dpas_layout(NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K)
 
     lhs_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=layout, operand_index=0, k_width=1)
     rhs_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=layout, operand_index=1, k_width=2)
@@ -315,18 +303,12 @@ def gluon_matmul_kernel_dpas_tensor_desc_batched(
     offset_b = bid.to(ttgl.int64) * stride_bz
     offset_c = bid.to(ttgl.int64) * stride_cz
 
-    a_desc = ttgl.intel.xpu.xe.make_tensor_descriptor(
-        a_ptr + offset_a, (M, K), (stride_am, stride_ak),
-        (BLOCK_SIZE_M, BLOCK_SIZE_K), lhs_layout
-    )
-    b_desc = ttgl.intel.xpu.xe.make_tensor_descriptor(
-        b_ptr + offset_b, (K, N), (stride_bk, stride_bn),
-        (BLOCK_SIZE_K, BLOCK_SIZE_N), rhs_layout
-    )
-    c_desc = ttgl.intel.xpu.xe.make_tensor_descriptor(
-        c_ptr + offset_c, (M, N), (stride_cm, stride_cn),
-        (BLOCK_SIZE_M, BLOCK_SIZE_N), layout
-    )
+    a_desc = ttgl.intel.xpu.xe.make_tensor_descriptor(a_ptr + offset_a, (M, K), (stride_am, stride_ak),
+                                                      (BLOCK_SIZE_M, BLOCK_SIZE_K), lhs_layout)
+    b_desc = ttgl.intel.xpu.xe.make_tensor_descriptor(b_ptr + offset_b, (K, N), (stride_bk, stride_bn),
+                                                      (BLOCK_SIZE_K, BLOCK_SIZE_N), rhs_layout)
+    c_desc = ttgl.intel.xpu.xe.make_tensor_descriptor(c_ptr + offset_c, (M, N), (stride_cm, stride_cn),
+                                                      (BLOCK_SIZE_M, BLOCK_SIZE_N), layout)
 
     # Clear accumulator
     zero_tensor = ttgl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ttgl.float32, layout=layout)
@@ -386,20 +368,12 @@ def matmul(
             triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
             B,
         )
-        matmul_kernel_batched[grid](
-            a, b, c,  #
-            B, M, N, K,  #
-            a.stride(0), a.stride(a_major), a.stride(a_minor),  #
-            b.stride(0), b.stride(b_minor), b.stride(b_major),  #
-            c.stride(0), c.stride(1), c.stride(2))
+        matmul_kernel_batched[grid](a, b, c, B, M, N, K, a.stride(0), a.stride(a_major), a.stride(a_minor), b.stride(0),
+                                    b.stride(b_minor), b.stride(b_major), c.stride(0), c.stride(1), c.stride(2))
     elif len(a.shape) == 2 and len(b.shape) == 2:
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
-        matmul_kernel[grid](
-            a, b, c,  #
-            M, N, K,  #
-            a.stride(a_major), a.stride(a_minor),  #
-            b.stride(b_minor), b.stride(b_major),  #
-            c.stride(0), c.stride(1))
+        matmul_kernel[grid](a, b, c, M, N, K, a.stride(a_major), a.stride(a_minor), b.stride(b_minor),
+                            b.stride(b_major), c.stride(0), c.stride(1))
     else:
         assert False, 'Input matrixs dimensions mismatch'
     return c
@@ -459,7 +433,7 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
     [4, 32768, 4096, 128],
     [32, 4096, 128, 4096],
     [4096, 8, 128, 16384],
-    # [4096, 8, 16384, 128],  # TODO: mismatches for gluon
+    [4096, 8, 16384, 128],
 ]
 
 DEVICE_NAME = torch.xpu.get_device_name()
@@ -498,13 +472,13 @@ def get_benchmark(
     supported_providers = {
         'gluon': 'Gluon',
         'triton': 'Triton',
-        'onednn': 'OneDNN',
+        #'onednn': 'OneDNN',
     }
     # use_cutlass
-    if not (transpose_a or transpose_b):
-        if torch.xpu.get_device_name() != 'Intel(R) Arc(TM) Graphics':
-            # FIXME: enable cutlass on LNL
-            supported_providers['cutlass'] = 'CUTLASS'
+    # if not (transpose_a or transpose_b):
+    #     if torch.xpu.get_device_name() != 'Intel(R) Arc(TM) Graphics':
+    #         # FIXME: enable cutlass on LNL
+    #         supported_providers['cutlass'] = 'CUTLASS'
     providers = benchmark_suite.filter_providers(supported_providers, providers_filter)
 
     # Benchmark Performance

utils/dpas_layout_analyzer.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+from functools import wraps
+from triton._C.libtriton import intel
+
+from triton.experimental.gluon.language.intel.xpu.xe import get_dpas_capabilities
+from triton.language.core import TRITON_BUILTIN
+
+
+def allow_in_kernel(fn):
+    """Mark a function as a builtin."""
+    assert callable(fn)
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        return fn(*args, **kwargs)
+
+    setattr(wrapper, TRITON_BUILTIN, True)
+
+    return wrapper
+
+
+@allow_in_kernel
+def calculate_optimal_warps_per_cta(num_warps, m_shape, n_shape):
+    ret_shape = [m_shape, n_shape]
+    dpas_cap = get_dpas_capabilities()
+    return intel.calculate_warps_per_tile(capRepeatCount=dpas_cap['repeatCount'],
+                                          capExecutionSize=dpas_cap['executionSize'], shape=ret_shape,
+                                          numWarps=num_warps)
+
+
+@allow_in_kernel
+def calculate_optimal_rep_clusters(block_m, block_n, block_k, threads_per_warp, warps_per_cta):
+    dtype_bitwidth = 16  # bf16 TODO: auto detect
+    is_fp8 = dtype_bitwidth == 8
+    dpas_cap = get_dpas_capabilities()
+    cap_repeat_count = dpas_cap['repeatCount']
+    cap_systolic_depth = dpas_cap['systolicDepth']
+    cap_execution_size = dpas_cap['executionSize']
+    ops_per_chan = int(dpas_cap['opsChanBitWidths'] / dtype_bitwidth)
+
+    ret_shape = [block_m, block_n]
+    a_shape = [block_m, block_k]
+    b_shape = [block_k, block_n]
+
+    rep_cluster = intel.calculate_rep_cluster(cap_repeat_count=cap_repeat_count, cap_systolic_depth=cap_systolic_depth,
+                                              cap_execution_size=cap_execution_size, ops_per_chan=ops_per_chan,
+                                              ret_shape=ret_shape, threads_per_warp=threads_per_warp,
+                                              a_bitwidth=dtype_bitwidth, is_fp8=is_fp8, a_shape=a_shape,
+                                              b_shape=b_shape, warps_per_tile=warps_per_cta)
+
+    return rep_cluster
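
Note: the actual warp assignment is delegated to the intel.calculate_warps_per_tile and intel.calculate_rep_cluster bindings from triton._C.libtriton. As a rough, hedged illustration only (not the bindings' exact algorithm), the warp-splitting idea can be sketched in plain Python like this; the greedy doubling rule and the default repeat_count/execution_size values are assumptions taken from the DPAS capabilities used above:

    # Hedged sketch, NOT the libtriton binding: greedily split num_warps across
    # the M/N dimensions of the output tile, favoring the dimension that still
    # has more DPAS tiles (repeat_count x execution_size) left to cover.
    def sketch_warps_per_tile(num_warps, m_shape, n_shape, repeat_count=8, execution_size=16):
        warps = [1, 1]
        while warps[0] * warps[1] < num_warps:
            tiles_m = m_shape / (repeat_count * warps[0])
            tiles_n = n_shape / (execution_size * warps[1])
            if tiles_m >= tiles_n:
                warps[0] *= 2
            else:
                warps[1] *= 2
        return warps

    print(sketch_warps_per_tile(32, 256, 256))  # -> [8, 4] for this toy heuristic

The real helpers additionally take block_k, threads_per_warp and the operand bit-width into account when picking rep_cluster, which is why calculate_optimal_rep_clusters needs more parameters than the sketch.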

python/triton/experimental/gluon/language/intel/xpu/xe.py

Lines changed: 49 additions & 39 deletions
@@ -2,19 +2,42 @@
 
 from typing import List, Tuple, Sequence
 from dataclasses import dataclass
+from functools import cache
 
 import triton.experimental.gluon.language._core as ttgl
 from triton.experimental.gluon.language._layouts import DotOperandLayout
 from triton.experimental.gluon.language.intel._layouts import IntelDPASLayout
 from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
 from triton.language.core import ir, constexpr, tensor_descriptor_base, block_type, tensor, tuple
 
-# load_tensor_descriptor = builtin(tl_core.load_tensor_descriptor)
-# store_tensor_descriptor = builtin(tl_core.store_tensor_descriptor)
-
 __all__ = ["make_tensor_descriptor", "dot_fma"]
 
 
+@cache
+def get_dpas_capabilities():
+    from triton.backends.intel.driver import XPUDriver
+
+    driver = XPUDriver()
+    target = driver.get_current_target()
+    properties = target.arch
+
+    # like annotate_module in passes
+    dpas_cap = {
+        "systolicDepth": 8,
+        "repeatCount": 8,
+        "executionSize": min(properties.get("sub_group_sizes", [16])),
+        "opsChanBitWidths": 32,
+        "has_subgroup_2d_block_io": properties.get("has_subgroup_2d_block_io", False),
+    }
+
+    return dpas_cap
+
+
+def is_2d_block_supported():
+    capabilities = get_dpas_capabilities()
+    return capabilities["has_subgroup_2d_block_io"]
+
+
 class tensor_descriptor(tensor_descriptor_base):
     """A descriptor representing a tensor in global memory."""
 
@@ -36,68 +59,55 @@ def _flatten_ir(self, handles: List[ir.value]) -> None:
         self.shape._flatten_ir(handles)
         self.strides._flatten_ir(handles)
 
+    def mark_2d_block_attribute(self, op, order, _semantic):
+        if order not in ('row_major', 'column_major'):
+            raise ValueError("Only row_major/column_major order is supported for 2d block")
+
+        attr = _semantic.builder.get_string_attr(order)
+        op.set_attr("ttig.block_io", attr)
+
     @builtin
     def load(self, offsets: Sequence[constexpr | tensor], _semantic=None) -> tensor:
         return _semantic.descriptor_load(self, offsets, "", "")
 
-    def load_2d(self, offsets: Sequence[constexpr | tensor], is_2d_block=False, _semantic=None) -> tensor:
-        # TODO: MaterializeBlockPointers.cpp
-        # Add 2d_block_io parameter + validation to set proper attribute
-        # Validation: (?)
-        # > 2 dims
-        # > stride 16 bytes aligned
-        # and others
+    @builtin
+    def load_2d(self, offsets: Sequence[constexpr | tensor], order: str = "row_major", _semantic=None) -> tensor:
+        if not is_2d_block_supported():
+            raise ValueError("2d block functionality is not supported for this hardware")
 
         op = _semantic.descriptor_load(self, offsets, "", "")
-
-        # TODO: proper handling like below test example
-        # Option to set row/column major and other params
-        attr = _semantic.builder.get_string_attr("row_major")
-        op.handle.set_attr("ttig.block_io", attr)
-
+        self.mark_2d_block_attribute(op.handle, order, _semantic)
         return op
 
     @builtin
     def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
         return _semantic.descriptor_store(self, value, offsets)
 
     @builtin
-    def store_2d(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
-        op = _semantic.descriptor_store(self, value, offsets)
-
-        attr = _semantic.builder.get_string_attr("row_major")
-        op.handle.set_attr("ttig.block_io", attr)
+    def store_2d(self, offsets: Sequence[constexpr | tensor], value: tensor, order: str = "row_major",
+                 _semantic=None) -> tensor:
+        if not is_2d_block_supported():
+            raise ValueError("2d block functionality is not supported for this hardware")
 
+        op = _semantic.descriptor_store(self, value, offsets)
+        self.mark_2d_block_attribute(op.handle, order, _semantic)
         return op
 
     @builtin
-    def prefetch(self, offsets: Sequence[constexpr | tensor], mask=None, cache=None, evict=None, is_volatile=False, _semantic=None):
+    def prefetch(self, offsets: Sequence[constexpr | tensor], _semantic=None):
         ptr_handle = self.handle
         offsets_handles = [offset.handle if hasattr(offset, 'handle') else offset for offset in offsets]
         return _semantic.builder.create_prefetch(ptr_handle, offsets_handles, False)
 
     @builtin
-    def prefetch_2d(self, offsets: Sequence[constexpr | tensor], mask=None, cache=None, evict=None, is_volatile=False, _semantic=None):
-        # TODO: handle other ttig.prefetch params
-        # ptr is just temporary, support for tensor descriptor is needed
-        # calculate offsets like tt.advance
-        # maybe add support for mask, seems optional
-        # also 2d block attr and others
-        #return _semantic.builder.create_prefetch(ptr.handle, False)
-        """
-        pyton/triton/language/semantic.py @ load:1077 (TritonSemantic)
-        cache_modifier: str, eviction_policy: str
-        cache = self._str_to_load_cache_modifier(cache_modifier)
-        eviction = self._str_to_eviction_policy(eviction_policy)
-        """
+    def prefetch_2d(self, offsets: Sequence[constexpr | tensor], order: str = "row_major", _semantic=None):
+        if not is_2d_block_supported():
+            raise ValueError("2d block functionality is not supported for this hardware")
 
         ptr_handle = self.handle
         offsets_handles = [offset.handle if hasattr(offset, 'handle') else offset for offset in offsets]
         op = _semantic.builder.create_prefetch(ptr_handle, offsets_handles, False)
-
-        attr = _semantic.builder.get_string_attr("row_major")
-        op.set_attr("ttig.block_io", attr)
-
+        self.mark_2d_block_attribute(op, order, _semantic)
         return op
 
 
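For orientation, a hedged usage sketch of the reworked capability query and 2D-block descriptor API. It runs only with the Intel XPU Triton backend and a device present; the offsets in the commented kernel lines are illustrative, and order defaults to "row_major" as in the diff above:

    from triton.experimental.gluon.language.intel.xpu.xe import (
        get_dpas_capabilities, is_2d_block_supported)

    caps = get_dpas_capabilities()  # cached; derived from the current XPU target's properties
    print(caps["executionSize"], caps["has_subgroup_2d_block_io"])

    # The 2d descriptor methods now raise early on hardware without 2D block IO.
    # Inside a @gluon.jit kernel one would write, e.g. (illustrative offsets):
    #     a_desc.prefetch_2d([off_m, off_k])                         # "row_major" by default
    #     tile = a_desc.load_2d([off_m, off_k])
    #     c_desc.store_2d([off_m, off_n], tile, order="row_major")
    if not is_2d_block_supported():
        print("2d block IO not available on this device")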