
Commit 83eb05c

wdziurdz authored and anmyachev committed
Align behavior with CUDA/HIP: skip test_matmul when swiglu_opts is not None and do_gamma is set
Signed-off-by: Witold Dziurdz <[email protected]>
(cherry picked from commit 1479afd)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent af6e538 · commit 83eb05c

File tree — 7 files changed: +19 −81 lines


python/triton_kernels/tests/test_matmul.py

Lines changed: 7 additions & 3 deletions

@@ -16,7 +16,7 @@
 # testing utilities
 from triton_kernels.testing import assert_close, make_random_tensor
 # target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4, is_xpu
 from triton_kernels.swiglu import swiglu, swiglu_fn
 from triton_kernels.swiglu import PrecisionConfig as SwiGLUPrecisionConfig

@@ -243,6 +243,10 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
         if split_k is not None and split_k > 1:
             pytest.skip("splitK hasn't been fully tested on AMD GPU.")
+    elif is_xpu():
+        if swiglu_opts is not None and do_gamma:
+            pytest.xfail("NYI: swiglu and gamma not supported together")
     if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
         pytest.xfail("float8_e4m3fnuz only tested on AMD CDNA3 Platform")

@@ -276,12 +280,12 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
         if hbm_swizzling:
             pytest.skip("NYI: nner_expt_opt and HBM swizzling")
         if not colmajor_mxfp_weight:
-            if torch.cuda.get_device_capability()[0] < 10:
+            if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
                 pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
             if block_m == 16:
                 pytest.skip("PassManager::run failed from Triton compiler")
     # TODO: should construct the test case differently rather than overriding here
-    if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 10:
+    if "float8" in weight_dtype_str and is_cuda() and torch.cuda.get_device_capability()[0] < 10:
         b_transpose = True

     torch.manual_seed(0)
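
For context, the new XPU branch mirrors the CUDA/HIP behavior: a swiglu epilogue combined with gamma scaling is not yet implemented on XPU, so the case is recorded as an expected failure instead of running to an error. A minimal sketch of the check's shape, with `is_xpu` as a hypothetical stand-in for `triton_kernels.target_info.is_xpu`:

import pytest
import torch

def is_xpu() -> bool:
    # Hypothetical stand-in for triton_kernels.target_info.is_xpu;
    # torch.xpu is present in recent PyTorch builds with XPU support.
    return hasattr(torch, "xpu") and torch.xpu.is_available()

def xfail_unsupported_xpu_cases(swiglu_opts, do_gamma: bool) -> None:
    # xfail (rather than skip) records the case as a known, tracked gap.
    if is_xpu() and swiglu_opts is not None and do_gamma:
        pytest.xfail("NYI: swiglu and gamma not supported together")

The final hunk applies the companion fix: both `torch.cuda.get_device_capability()` checks are now gated on `is_cuda()`, since the capability query is only meaningful when a CUDA device is active.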

python/triton_kernels/tests/test_reduce.py

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ def plus_a_reduce(x, a):
 @pytest.mark.parametrize("dim", [0, 1, 2])
 def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn, device):
     is_hip = triton.runtime.driver.active.get_current_target().backend == "hip"
-    is_pre_h100 = torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0)
+    is_pre_h100 = device == "cuda" and torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0)
     if (is_hip or is_pre_h100) and "float8" in dtype_str:
         pytest.skip("float8 not supported on CUDA < 9.0")
     torch.manual_seed(0)
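
This fix relies on Python's left-to-right short-circuit evaluation: once `device == "cuda"` is false, the `torch.cuda` calls are never reached, so the check is safe on XPU test runs. The guard in isolation, as a hedged sketch (the helper name is hypothetical):

import torch

def is_pre_h100(device: str) -> bool:
    # Mirrors the guard above; the device check short-circuits before
    # any torch.cuda query is made on non-CUDA backends.
    return (
        device == "cuda"
        and torch.cuda.is_available()
        and torch.cuda.get_device_capability() < (9, 0)
    )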

python/triton_kernels/triton_kernels/matmul.py

Lines changed: 1 addition & 1 deletion

@@ -318,7 +318,7 @@ def matmul(a, b, bias,
     )
     # there seems to be a bug on A100
     # pytest -vs test_matmul.py::test_op[False-False-False-False-pad_b-16-768-512-1024-ragged-float16-float16-10-1-False-None-False-False-False-True-None]
-    if ragged_dimension == "K" and torch.cuda.get_device_capability()[0] < 9:
+    if ragged_dimension == "K" and is_cuda() and torch.cuda.get_device_capability()[0] < 9:
         opt_flags.num_stages = 1
     if ragged_dimension == "K":
         a_has_tma = opt_flags.is_persistent and (a.stride(-1) != 1 or (a_ragged_metadata.slice_sizes_divisibility is not None))
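
The same guard protects the A100 (`sm80`) workaround in the library code. One way to keep such checks compact is to fetch the CUDA major capability once and treat "not a CUDA device" as None; a hedged sketch, not the library's actual helper:

import torch

def cuda_major_capability() -> int | None:
    # Returns the CUDA major compute capability, or None when the
    # active backend is not CUDA (note ROCm builds also report
    # torch.cuda.is_available() as True, hence the hip check).
    if not torch.cuda.is_available() or torch.version.hip is not None:
        return None
    return torch.cuda.get_device_capability()[0]

# Callers can then write:
#   cc = cuda_major_capability()
#   if ragged_dimension == "K" and cc is not None and cc < 9: ...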

python/triton_kernels/triton_kernels/matmul_details/opt_flags.py

Lines changed: 8 additions & 8 deletions

@@ -54,7 +54,7 @@ def make_default_opt_flags_intel(
     m,
     n,
     k,
-    routing_data,
+    ragged_metadata,
     can_use_persistent_tma,
     can_use_split_k,
     enforce_bitwise_invariance,

@@ -65,13 +65,13 @@ def make_default_opt_flags_intel(
 ):
     constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "epilogue_subtile", "num_stages", "max_allowable_mn"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
-    # tokens per expert
-    if routing_data is None:
-        tokens_per_expt = m
-    elif routing_data.expected_tokens_per_expt is None:
-        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+    # tokens per slice
+    if ragged_metadata is None:
+        slice_size = m
+    elif ragged_metadata.expected_slice_size is None:
+        slice_size = max(1, m // ragged_metadata.n_slices)
     else:
-        tokens_per_expt = routing_data.expected_tokens_per_expt
+        slice_size = ragged_metadata.expected_slice_size
     # pid swizzling
     group_m = 8
     xcd_swizzle = 1

@@ -81,7 +81,7 @@ def make_default_opt_flags_intel(
     elif enforce_bitwise_invariance:
         block_m = 128
     else:
-        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+        block_m = max(16, min(triton.next_power_of_2(slice_size), 128))
     # block n
     block_n = opt_flags_intel.compute_block_n(n)
     # is_persistent
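
The renamed heuristic stands alone well: estimate the expected number of rows per ragged slice, then round to a power of two clamped to [16, 128] to get block_m. A minimal sketch, with `RaggedMetadata` as a hypothetical stand-in modeling only the two fields the heuristic reads:

from dataclasses import dataclass

import triton

@dataclass
class RaggedMetadata:
    # Hypothetical stand-in for the real ragged-metadata object.
    n_slices: int
    expected_slice_size: int | None = None

def pick_block_m(m: int, ragged_metadata: RaggedMetadata | None) -> int:
    # Expected rows per slice: the full M dimension for dense inputs, an
    # even split across slices when no estimate is given, or the estimate.
    if ragged_metadata is None:
        slice_size = m
    elif ragged_metadata.expected_slice_size is None:
        slice_size = max(1, m // ragged_metadata.n_slices)
    else:
        slice_size = ragged_metadata.expected_slice_size
    # Round to a power of two, clamped to [16, 128].
    return max(16, min(triton.next_power_of_2(slice_size), 128))

For example, pick_block_m(1000, RaggedMetadata(n_slices=8)) estimates 125 rows per slice and yields block_m = 128, while pick_block_m(100, RaggedMetadata(n_slices=8)) estimates 12 and clamps up to the minimum of 16.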

python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_intel.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@

 def compute_grid_size(routing_data, m, n, block_m, block_n):
     if routing_data is not None:
-        grid_m = routing_data.n_blocks(m, block_m)
+        grid_m = routing_data.n_blocks(routing_data.n_slices, m, block_m)
     else:
         grid_m = triton.cdiv(m, block_m)
     grid_n = (n + block_n - 1) // block_n

@@ -19,7 +19,7 @@ def compute_block_n(n: int):
 def compute_block_k(k: int | None, is_persistent: bool, precision_config):
     if k is not None:
         block_k = max(32, min(128, triton.next_power_of_2(k)))
-        has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
+        has_mx_weight_scale = precision_config is not None and precision_config.b_mx_scale is not None
         if is_persistent and has_mx_weight_scale:
             block_k = min(block_k, 128)
     return block_k
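
The block_k heuristic reads only one field of the precision config, renamed here from `weight_scale` to `b_mx_scale` (presumably the B operand's microscaling factors). A standalone sketch under that assumption:

import triton

def compute_block_k_sketch(k: int | None, is_persistent: bool, b_mx_scale) -> int | None:
    # Clamp next-power-of-two of K to [32, 128]; persistent kernels
    # carrying microscaled (mx) weight scales are re-capped at 128.
    block_k = None
    if k is not None:
        block_k = max(32, min(128, triton.next_power_of_2(k)))
        if is_persistent and b_mx_scale is not None:
            block_k = min(block_k, 128)
    return block_k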
Lines changed: 0 additions & 17 deletions

@@ -1,17 +0,0 @@
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/5074
-tests/test_matmul.py::test_op[False-False-False-True-None0-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-False-True-None0-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-False-True-None1-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-False-True-None1-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-True-None0-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-True-None0-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-True-None1-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-True-None1-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-True-None0-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-True-None0-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-True-None1-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-True-None1-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-True-None0-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-True-None0-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-True-None1-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-True-None1-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
Lines changed: 0 additions & 49 deletions

@@ -1,49 +0,0 @@
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/5074
-tests/test_matmul.py::test_op[False-False-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False]
-tests/test_matmul.py::test_op[False-False-False-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-False-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False]
-tests/test_matmul.py::test_op[False-False-False-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False]
-tests/test_matmul.py::test_op[False-False-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False]
-tests/test_matmul.py::test_op[False-False-True-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-False-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False]
-tests/test_matmul.py::test_op[False-True-False-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False]
-tests/test_matmul.py::test_op[False-True-False-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-False-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-False-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False]
-tests/test_matmul.py::test_op[False-True-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-False-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-True-False-None-128-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-False]
-tests/test_matmul.py::test_op[False-True-True-True-False-None-16-1000-400-400-ragged-float8_e4m3fn-float8_e4m3fn-3-1-1-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-False-None-False-False-False-True]
-tests/test_matmul.py::test_op[False-True-True-True-False-None-16-1000-704-800-ragged-mxfloat8_e4m3fn-mxfloat4_e2m1-8-2-9-True-None-False-False-False-True]
