Commit e33b846

wdziurdz authored and anmyachev committed
Align behavior with CUDA/HIP: skip test_matmul when swiglu_opts is not None and do_gamma is set
Signed-off-by: Witold Dziurdz <[email protected]>
(cherry picked from commit 1479afd)
1 parent af6e538 · commit e33b846

3 files changed: +15 -11 lines changed


python/triton_kernels/tests/test_matmul.py

Lines changed: 5 additions & 1 deletion
@@ -16,7 +16,7 @@
 # testing utilities
 from triton_kernels.testing import assert_close, make_random_tensor
 # target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4, is_xpu
 from triton_kernels.swiglu import swiglu, swiglu_fn
 from triton_kernels.swiglu import PrecisionConfig as SwiGLUPrecisionConfig

@@ -243,6 +243,10 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
         if split_k is not None and split_k > 1:
             pytest.skip("splitK hasn't been fully tested on AMD GPU.")

+    elif is_xpu():
+        if swiglu_opts is not None and do_gamma:
+            pytest.xfail("NYI: swiglu and gamma not supported together")
+
     if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
         pytest.xfail("float8_e4m3fnuz only tested on AMD CDNA3 Platform")
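
For context, the new guard follows the same pattern the suite already uses for the CUDA/HIP branches: probe the backend with a helper from triton_kernels.target_info and mark the unsupported combination as an expected failure before any kernel is built. Below is a minimal standalone sketch of that pattern; the local is_xpu stub is an assumption for illustration only, the real helper being the one imported in the diff above.

    import pytest

    def is_xpu() -> bool:
        # Stand-in for triton_kernels.target_info.is_xpu; assume a non-XPU host here.
        return False

    def maybe_mark_unsupported(swiglu_opts, do_gamma):
        # Mirrors the added guard: fusing swiglu with gamma is not yet implemented on XPU,
        # so the case is reported as an expected failure instead of erroring out.
        if is_xpu() and swiglu_opts is not None and do_gamma:
            pytest.xfail("NYI: swiglu and gamma not supported together")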

python/triton_kernels/triton_kernels/matmul_details/opt_flags.py

Lines changed: 8 additions & 8 deletions
@@ -54,7 +54,7 @@ def make_default_opt_flags_intel(
     m,
     n,
     k,
-    routing_data,
+    ragged_metadata,
     can_use_persistent_tma,
     can_use_split_k,
     enforce_bitwise_invariance,
@@ -65,13 +65,13 @@ def make_default_opt_flags_intel(
 ):
     constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "epilogue_subtile", "num_stages", "max_allowable_mn"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
-    # tokens per expert
-    if routing_data is None:
-        tokens_per_expt = m
-    elif routing_data.expected_tokens_per_expt is None:
-        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+    # tokens per slice
+    if ragged_metadata is None:
+        slice_size = m
+    elif ragged_metadata.expected_slice_size is None:
+        slice_size = max(1, m // ragged_metadata.n_slices)
     else:
-        tokens_per_expt = routing_data.expected_tokens_per_expt
+        slice_size = ragged_metadata.expected_slice_size
     # pid swizzling
     group_m = 8
     xcd_swizzle = 1
@@ -81,7 +81,7 @@ def make_default_opt_flags_intel(
     elif enforce_bitwise_invariance:
         block_m = 128
     else:
-        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+        block_m = max(16, min(triton.next_power_of_2(slice_size), 128))
     # block n
     block_n = opt_flags_intel.compute_block_n(n)
     # is_persistent
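
This file only renames the ragged-batch metadata that drives the block_m heuristic (routing_data / tokens_per_expt become ragged_metadata / slice_size); the fallback logic itself is unchanged. A rough self-contained sketch of that heuristic is below, with a hypothetical RaggedMetadata dataclass standing in for the real type and a bit_length-based replacement for triton.next_power_of_2.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class RaggedMetadata:
        # Hypothetical stand-in: only the two fields read by the heuristic are modeled.
        n_slices: int
        expected_slice_size: Optional[int] = None

    def resolve_slice_size(ragged_metadata: Optional[RaggedMetadata], m: int) -> int:
        # Same fallback chain as in make_default_opt_flags_intel after the rename.
        if ragged_metadata is None:
            return m
        if ragged_metadata.expected_slice_size is None:
            return max(1, m // ragged_metadata.n_slices)
        return ragged_metadata.expected_slice_size

    def default_block_m(slice_size: int) -> int:
        # block_m heuristic from the same function: clamp the next power of two to [16, 128].
        next_pow2 = 1 << max(0, slice_size - 1).bit_length()
        return max(16, min(next_pow2, 128))

For example, resolve_slice_size(RaggedMetadata(n_slices=8), m=1024) gives 128, and default_block_m(128) picks block_m = 128.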

python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_intel.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@

 def compute_grid_size(routing_data, m, n, block_m, block_n):
     if routing_data is not None:
-        grid_m = routing_data.n_blocks(m, block_m)
+        grid_m = routing_data.n_blocks(routing_data.n_slices, m, block_m)
     else:
         grid_m = triton.cdiv(m, block_m)
     grid_n = (n + block_n - 1) // block_n
@@ -19,7 +19,7 @@ def compute_block_n(n: int):
 def compute_block_k(k: int | None, is_persistent: bool, precision_config):
     if k is not None:
         block_k = max(32, min(128, triton.next_power_of_2(k)))
-    has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
+    has_mx_weight_scale = precision_config is not None and precision_config.b_mx_scale is not None
     if is_persistent and has_mx_weight_scale:
         block_k = min(block_k, 128)
     return block_k
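
The last file adapts the Intel flag helpers to the renamed interfaces: routing_data.n_blocks now also receives the slice count, and the MX weight scale is read from precision_config.b_mx_scale instead of precision_config.weight_scale. For the dense branch of compute_grid_size (no ragged metadata), the grid arithmetic reduces to plain ceiling division; a small sketch with a local cdiv standing in for triton.cdiv:

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, equivalent to triton.cdiv.
        return -(a // -b)

    def dense_grid(m: int, n: int, block_m: int, block_n: int) -> tuple[int, int]:
        # Dense fallback of compute_grid_size: tile rows by block_m and columns by block_n.
        return cdiv(m, block_m), cdiv(n, block_n)

    # Example: a 1000 x 768 problem with 128 x 64 tiles launches an 8 x 12 grid.
    assert dense_grid(1000, 768, 128, 64) == (8, 12)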
