
Commit b655ab7

[Bench][AMD] Fix torch ref routing and enable CI (#7183)
- Fixed the failing tests that were disabled in triton-lang/triton#7166.
- Skipped several still-failing tests to keep the pipeline green for now.
- Added the bench tests to the gfx950 and gfx942 CI pipelines.

Parent: b78022a

File tree: 5 files changed (+21, -3)

.github/workflows/integration-tests-amd.yml

Lines changed: 6 additions & 0 deletions
@@ -116,6 +116,12 @@ jobs:
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
         fi
 
+        # Run tests under triton/python/triton_kernels/tests/ on gfx950 and gfx942
+        if [ "${{ matrix.runner[0] }}" = "amd-gfx950" ] || [ "${{ matrix.runner[0] }}" = "amd-gfx942" ]; then
+          cd ../../triton_kernels/
+          python3 -m pytest -s -n 12 tests/
+        fi
+
     - name: Run asan tests on AMD
       if: false
       run: |

python/triton_kernels/tests/test_matmul.py

Lines changed: 4 additions & 0 deletions
@@ -253,6 +253,10 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     if split_k > 1:
         pytest.skip("splitK hasn't been fully tested on AMD GPU.")
 
+    if is_hip_cdna3() and ("float8_e4m3fn" in (weight_dtype_str, act_dtype_str)
+                           or "float8_e5m2" in (weight_dtype_str, act_dtype_str)):
+        pytest.skip("float8_e4m3fn and float8_e5m2 hasn't been fully tested on AMD CDNA3 platform.")
+
     if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
         pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform")

python/triton_kernels/tests/test_mxfp.py

Lines changed: 8 additions & 0 deletions
@@ -22,6 +22,7 @@
     upcast_from_mxfp_torch,
 )
 from triton_kernels.testing import assert_close, assert_equal
+from triton_kernels.target_info import is_hip, is_hip_cdna3
 
 
 def dtype_str_to_torch(dtype_str: str) -> torch.dtype:
@@ -142,6 +143,13 @@ def test_mxfp_casting(
         pytest.skip("Hopper swizzle not supported for tile not multiple of 64x128")
     if user_allocated_output and any([swizzle_value, swizzle_scale]):
         pytest.skip("User-allocated output not supported together with swizzling")
+    if is_hip():
+        if swizzle_value is not None or swizzle_scale is not None:
+            pytest.skip("Other swizzling patterns are not supported by AMD GPU")
+        if quant_dtype == 'float8_e4m3fn':
+            pytest.skip("float8_e4m3fn cast hasn't been fully tested on AMD GPU")
+        if quant_dtype == 'float8_e5m2' and is_hip_cdna3():
+            pytest.skip("float8_e5m2 cast hasn't been fully tested on AMD CDNA3")
 
     swizzle_axis = swizzle_axis if (swizzle_value or swizzle_scale) else None
     quant_torch_type = dtype_str_to_torch(quant_dtype)
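
The new skips depend on is_hip() and is_hip_cdna3() from triton_kernels.target_info, whose bodies are not part of this diff. A rough sketch of how such checks can be written against a ROCm build of PyTorch (an assumption about the real implementation, not a copy of it):

import torch

def is_hip() -> bool:
    # A ROCm (HIP) build of PyTorch reports a non-None torch.version.hip.
    return torch.version.hip is not None

def is_hip_cdna3() -> bool:
    # CDNA3 devices (MI300 series) report a gfx942 architecture string.
    return is_hip() and "gfx942" in torch.cuda.get_device_properties(0).gcnArchName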

python/triton_kernels/tests/test_routing.py

Lines changed: 0 additions & 2 deletions
@@ -3,7 +3,6 @@
 from triton_kernels.routing import routing, routing_torch
 from triton_kernels.testing import assert_close
 from triton_kernels.testing import assert_equal
-from triton_kernels.target_info import is_hip
 
 
 def init_data(n_tokens, n_expts_tot, dtype=torch.float32, device="cuda"):
@@ -19,7 +18,6 @@ def init_data(n_tokens, n_expts_tot, dtype=torch.float32, device="cuda"):
 @pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 32), (1500, 8)])
 @pytest.mark.parametrize("use_expt_indx", [False, True])
 @pytest.mark.parametrize("sm_first", [True, False])
-@pytest.mark.skipif(is_hip(), reason="Tests are currently broken on AMD")
 def test_op(n_tokens_pad, n_tokens_raw, n_expts_tot, n_expts_act, sm_first, use_expt_indx, device):
     torch.manual_seed(2)
     if n_tokens_raw is None:

python/triton_kernels/triton_kernels/routing.py

Lines changed: 3 additions & 1 deletion
@@ -271,6 +271,8 @@ def compute_expt_data_torch(hist, n_expts_tot, n_gates):
     token_offs_raw = token_offs_raw.int()
     # maximum number of tiles for all values of `block_m` considered
     block_ms = [16, 32, 64, 128]
+    if is_hip():
+        block_ms.append(256)
     if n_gates <= n_expts_tot:
         max_n_tiles = n_gates
     else:
@@ -280,7 +282,7 @@
     # fill up tile offset/infos for each block
     token_offs_pad = dict()
     block_pid_map = dict()
-    for block_m in [16, 32, 64, 128]:
+    for block_m in block_ms:
         n_tiles = (hist + block_m - 1) // block_m  # matmul blocks needed
         token_offs_pad[block_m] = torch.cumsum(n_tiles, dim=0)
         token_offs_pad[block_m] = torch.cat((torch.zeros(1, device=device), token_offs_pad[block_m]))
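
The torch reference must precompute token_offs_pad and block_pid_map for every block_m the Triton matmul kernels may select; on AMD hardware the kernels can presumably also pick BLOCK_M = 256, so the reference now covers that entry as well, which appears to be the routing fix the commit title refers to. A standalone example of the padded tile-offset computation (made-up token counts, mirroring the reference code above):

import torch

hist = torch.tensor([5, 130, 0, 64])       # tokens routed to each of 4 experts
block_m = 128
n_tiles = (hist + block_m - 1) // block_m  # ceil-div: matmul tiles per expert -> [1, 2, 0, 1]
token_offs_pad = torch.cumsum(n_tiles, dim=0)
token_offs_pad = torch.cat((torch.zeros(1, dtype=torch.long), token_offs_pad))
print(token_offs_pad)                      # tensor([0, 1, 3, 3, 4])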
