
Commit 632d234

Implement make_opt_flags function for XPU and enable tests in test_matmul.py (#5051)
**Note for reviewers:** I've kept only the most basic heuristics. If you have improvements in mind that are already known to work better and don't need testing, we can apply them directly in this pull request. If you have improvements that would need testing, please mention them as well, but I'd prefer to land this basic version as quickly as possible and tune it in separate PRs.

Pass rate: 84.11% -> 89.04%

---------

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent a0e532a commit 632d234

9 files changed: +168 −28 lines


python/triton_kernels/tests/test_matmul.py

Lines changed: 23 additions & 18 deletions
@@ -20,7 +20,7 @@
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_xpu, is_hip_cdna4

 # ---------------
 # initialize data
@@ -73,7 +73,7 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
     if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
         gs0 = None
         gs1 = None
-    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
         w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
     return x, w, bias, gs0, gs1

@@ -294,6 +294,10 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         if split_k > 1:
             pytest.skip("splitK hasn't been fully tested on AMD GPU.")

+    elif is_xpu():
+        if split_k > 1:
+            pytest.skip("FIXME: https://github.com/intel/intel-xpu-backend-for-triton/issues/5074")
+
     if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
         pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform")

@@ -308,20 +312,21 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
             pytest.skip("Non-scale swizzling not supported on CDNA4 yet")
         if n % 32 != 0 or k % (32 * 8) != 0:
             pytest.skip(f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPU")
-    if torch.cuda.get_device_capability()[0] < 9:
-        pytest.skip("NYI. Ampere swizzling.")
-    if torch.cuda.get_device_capability()[0] < 10:
-        if "mxfloat4" not in weight_dtype_str:
-            pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
-        if k % 64 != 0 or n % 64 != 0:
-            # Automatic padding not implemented for Hopper swizzle
-            pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
+    if is_cuda():
+        if torch.cuda.get_device_capability()[0] < 9:
+            pytest.skip("NYI. Ampere swizzling.")
+        if torch.cuda.get_device_capability()[0] < 10:
+            if "mxfloat4" not in weight_dtype_str:
+                pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
+            if k % 64 != 0 or n % 64 != 0:
+                # Automatic padding not implemented for Hopper swizzle
+                pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")

     # launch metadata for batched / mx types may not work yet.
     torch.manual_seed(0)

     block_k = None
-    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
         # Override block_k for testing correctness. The default is temporarily 128 for
         # performance reasons which doesn't work with persistent matmul.
         # TODO: revisit when Triton is better for H100 + MXFP4
@@ -436,7 +441,7 @@ def round_x(x, idx):

     round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y
     ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref, #
-                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref)
+                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref, device=device)
     scale = lambda val, scal: val if scal is None else val / scal
     if n_expt_shards > 1:
         if do_scatter:
@@ -549,21 +554,21 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
     (4096, 4096, 0),
 ])
 @pytest.mark.parametrize("view_x_as_zero_cols", [False, True])
-def test_zero_reduction_dim(m, n, k, view_x_as_zero_cols):
+def test_zero_reduction_dim(m, n, k, view_x_as_zero_cols, device):
     torch.manual_seed(0)

     if view_x_as_zero_cols:
-        x = torch.randn(m, m, device="cuda", dtype=torch.bfloat16)
+        x = torch.randn(m, m, device=device, dtype=torch.bfloat16)
         x = x[:0, :].transpose(-1, -2)
     else:
-        x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
-    w = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)
-    bias = torch.randn(n, device="cuda", dtype=torch.float32)
+        x = torch.randn(m, k, device=device, dtype=torch.bfloat16)
+    w = torch.randn(k, n, device=device, dtype=torch.bfloat16)
+    bias = torch.randn(n, device=device, dtype=torch.float32)

     try:
         tri_y = matmul_ogs(x, w, bias)
     except opt_flags.InapplicableConstraint:
         pytest.skip("inapplicable constraint")
-    ref_y = matmul_ogs_torch(x, w, bias, round_x=lambda x, idx: x, round_y=lambda y: y)
+    ref_y = matmul_ogs_torch(x, w, bias, round_x=lambda x, idx: x, round_y=lambda y: y, device=device)

     assert_close(ref_y, tri_y)
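
The `device` argument added above comes from a pytest fixture rather than a hard-coded "cuda" string. The conftest.py that provides it is not part of this diff; the following is only a minimal sketch of how such a fixture is typically wired up, assuming the option name matches the `--device xpu` flag passed by scripts/test-triton.sh below:

    # conftest.py sketch (illustrative, not from this commit)
    import pytest

    def pytest_addoption(parser):
        # assumed to mirror the --device flag used in scripts/test-triton.sh
        parser.addoption("--device", action="store", default="cuda")

    @pytest.fixture
    def device(request):
        # e.g. "cuda" or "xpu"; the tests above allocate tensors on this device
        return request.config.getoption("--device")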

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 2 additions & 2 deletions
@@ -549,7 +549,7 @@ def matmul_ogs_torch(x, w, bias,
                      betas = None,
                      gammas = None,
                      round_x = None, round_y = None,
-                     ):
+                     device: str = "cuda"):
     is_input_batched = x.ndim == 3
     assert x.dtype.itemsize > 1
     assert w.dtype.itemsize > 1
@@ -588,7 +588,7 @@ def matmul_ogs_torch(x, w, bias,
         else:
             idx = gather_indx.src_indx[lo:hi] // n_expts_act
         batch = i if is_input_batched else 0
-        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
+        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=device)).float(),
                            w[i].float())
         if bias is not None:
             out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
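
With `device` threaded through, the reference implementation no longer assumes CUDA for its internal `torch.arange`. Below is a minimal usage sketch on XPU, mirroring the call pattern from test_zero_reduction_dim above (the shapes are illustrative and the import paths are assumed from the package layout, not taken from this diff):

    import torch
    from triton_kernels.matmul_ogs import matmul_ogs, matmul_ogs_torch
    from triton_kernels.testing import assert_close

    device = "xpu"  # assumes an XPU-enabled PyTorch build
    x = torch.randn(64, 32, device=device, dtype=torch.bfloat16)
    w = torch.randn(32, 16, device=device, dtype=torch.bfloat16)
    bias = torch.randn(16, device=device, dtype=torch.float32)

    tri_y = matmul_ogs(x, w, bias)
    # the torch.arange inside matmul_ogs_torch now lands on `device` instead of "cuda"
    ref_y = matmul_ogs_torch(x, w, bias, round_x=lambda v, idx: v, round_y=lambda v: v, device=device)
    assert_close(ref_y, tri_y)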

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 81 additions & 1 deletion
@@ -4,7 +4,7 @@
 import triton
 from triton_kernels.target_info import get_cdna_version
 import torch
-from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel


 @dataclass
@@ -30,6 +30,84 @@ def __post_init__(self):
         raise ValueError("Not supported")


+def make_default_opt_flags_intel(
+    out_dtype,
+    lhs_dtype,
+    rhs_dtype,
+    precision_config,
+    m,
+    n,
+    k,
+    routing_data,
+    can_use_persistent_tma,
+    can_use_fused_scatter,
+    enforce_bitwise_invariance,
+    epilogue_effective_itemsize,
+    constraints,
+):
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages"]
+    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+    # tokens per expert
+    if routing_data is None:
+        tokens_per_expt = m
+    elif routing_data.expected_tokens_per_expt is None:
+        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+    else:
+        tokens_per_expt = routing_data.expected_tokens_per_expt
+    # pid swizzling
+    group_m = 8
+    xcd_swizzle = 1
+    # block_m
+    if constraints.get("block_m", None):
+        block_m = constraints["block_m"]
+    elif enforce_bitwise_invariance:
+        block_m = 128
+    else:
+        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+    # block n
+    block_n = opt_flags_intel.compute_block_n(n)
+    # is_persistent
+    is_persistent = constraints.get("is_persistent", False)
+    # block k
+    if constraints.get("block_k", None) is not None:
+        block_k = constraints["block_k"]
+    else:
+        block_k = opt_flags_intel.compute_block_k(k, is_persistent, precision_config)
+    # split_k
+    if constraints.get("split_k", None) is not None:
+        split_k = constraints["split_k"]
+    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+        split_k = 1
+    else:
+        estimated_actual_grid_size = opt_flags_intel.compute_grid_size(None, m, n, block_m, block_n)
+        split_k = opt_flags_intel.compute_split_k(block_k, k, estimated_actual_grid_size)
+
+    epilogue_subtile = constraints.get('epilogue_subtile', None)
+    if epilogue_subtile is None:
+        epilogue_subtile = 1
+
+    ret = OptFlags(
+        block_m=block_m,
+        block_n=block_n,
+        block_k=block_k,
+        num_warps=opt_flags_intel.compute_num_warps(block_m, block_n),
+        num_stages=constraints.get("num_stages", 2),
+        fused_scatter=constraints.get('fused_scatter', False),
+        group_m=group_m,
+        xcd_swizzle=xcd_swizzle,
+        w_cache_modifier=None,
+        split_k=split_k,
+        is_persistent=is_persistent,
+        epilogue_subtile=epilogue_subtile,
+        arch=None,
+        target_kernel_kwargs=dict(),
+        idle_sms=0,
+    )
+    # check constraints
+    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    return ret
+
+
 def make_default_opt_flags_amd(
     out_dtype,
     lhs_dtype,
@@ -296,6 +374,8 @@ def make_opt_flags(
             enforce_bitwise_invariance, epilogue_effective_itemsize,
             _opt_flags_constraints]
     backend = triton.runtime.driver.active.get_current_target().backend
+    if backend == "xpu":
+        return make_default_opt_flags_intel(*args)
     if backend == "hip":
         return make_default_opt_flags_amd(*args)
     if backend == "cuda":
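
For intuition, the block_m branch above behaves like this for a ragged workload (a sketch with made-up numbers, not taken from this commit):

    import triton

    m, n_expts_tot = 1000, 8
    tokens_per_expt = max(1, m // n_expts_tot)                            # 125
    block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))  # -> 128

block_n, block_k and split_k are delegated to the opt_flags_intel helpers added in the new file below.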
python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+import torch
+import triton
+
+
+def compute_grid_size(routing_data, m, n, block_m, block_n):
+    if routing_data is not None:
+        grid_m = routing_data.n_blocks(m, block_m)
+    else:
+        grid_m = triton.cdiv(m, block_m)
+    grid_n = (n + block_n - 1) // block_n
+    return grid_m * grid_n
+
+
+def compute_block_n(n: int):
+    # block_n:
+    return max(16, min(128, triton.next_power_of_2(n)))
+
+
+def compute_block_k(k: int | None, is_persistent: bool, precision_config):
+    if k is not None:
+        block_k = max(32, min(128, triton.next_power_of_2(k)))
+    has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
+    if is_persistent and has_mx_weight_scale:
+        block_k = min(block_k, 128)
+    return block_k
+
+
+def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
+    device_props = torch.xpu.get_device_properties(0)
+    n_sms = device_props.gpu_subslice_count
+    split_k = n_sms // grid_size
+    if k is not None:
+        # avoid split_k for small k
+        num_block_k = triton.cdiv(k, block_k)
+        split_k = min(split_k, num_block_k // 4)
+    split_k = max(split_k, 1)
+    return split_k
+
+
+def compute_num_warps(block_m, block_n):
+    return max(block_m * block_n // 4096, 4)
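
A worked example of how these heuristics combine for the 1000x400x400 ragged shape that appears in the skiplists below (the subslice count is an assumed value; at runtime it comes from torch.xpu.get_device_properties):

    import triton

    m, n, k = 1000, 400, 400
    block_m, block_n = 128, 128                                         # block_m rule and compute_block_n
    block_k = max(32, min(128, triton.next_power_of_2(k)))              # 128
    grid = triton.cdiv(m, block_m) * ((n + block_n - 1) // block_n)     # 8 * 4 = 32
    n_sms = 64                                                          # assumed gpu_subslice_count
    split_k = max(1, min(n_sms // grid, triton.cdiv(k, block_k) // 4))  # min(2, 4 // 4) -> 1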

python/triton_kernels/triton_kernels/target_info.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
     "has_native_mxfp",
     "is_cuda",
     "is_hip",
+    "is_xpu",
     "is_hip_cdna3",
     "is_hip_cdna4",
     "is_xpu",
Lines changed: 9 additions & 3 deletions
@@ -1,3 +1,9 @@
-tests/test_matmul.py::test_op
-tests/test_matmul.py::test_fused_act
-tests/test_matmul.py::test_zero_reduction_dim
+# https://github.com/intel/intel-xpu-backend-for-triton/issues/5074
+tests/test_matmul.py::test_op[False-False-False-True-False-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-False-False-True-False-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-False-True-True-False-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-False-True-True-False-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-True-False-True-False-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-True-False-True-False-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-True-True-True-False-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-True-True-True-False-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
Lines changed: 9 additions & 3 deletions
@@ -1,3 +1,9 @@
-tests/test_matmul.py::test_op
-tests/test_matmul.py::test_fused_act
-tests/test_matmul.py::test_zero_reduction_dim
+# https://github.com/intel/intel-xpu-backend-for-triton/issues/5074
+tests/test_matmul.py::test_op[False-False-False-True-False-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-False-False-True-False-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-False-True-True-False-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-False-True-True-False-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-True-False-True-False-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-True-False-True-False-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-True-True-True-False-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]
+tests/test_matmul.py::test_op[False-True-True-True-False-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-1-False-None]

scripts/test-triton.sh

Lines changed: 1 addition & 1 deletion
@@ -577,7 +577,7 @@ run_triton_kernels_tests() {
   cd $TRITON_PROJ/python/triton_kernels/tests

   TRITON_TEST_SUITE=triton_kernels \
-    run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu .
+    run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:-4} --device xpu .
 }

 test_triton() {

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ class XPUOptions:
     generate_native_code: bool = False
     advanced_path: bool = False
     enable_tile_load_linear_layout: bool = True
+    arch: str = None
     # FIXME: enable for XPU: https://github.com/intel/intel-xpu-backend-for-triton/issues/4954
     instrumentation_mode: str = ""
