     is_hip,
     is_hip_cdna,
     is_hip_mi200,
+    is_hip_mi300,
     is_xpu,
     get_arch,
     torch_float8_dtypes,
@@ -3414,8 +3415,8 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack |
     if is_hip():
         if not is_hip_cdna():
             pytest.skip("scaled_dot only implemented for HIP CDNA")
-        if (type_a not in ["e2m1", "e5m2"]) or (type_b not in ["e2m1", "e5m2", "bf16"]):
-            pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
+        if "e4m3" in (type_a, type_b) and not is_hip_mi300():
+            pytest.skip(f"scaled_dot({type_a}, {type_b}) only implemented for MI300")
         if mma == 16 and K == 64:
             pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
     if is_xpu():
@@ -6072,3 +6073,33 @@ def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr): |
     Z = torch.zeros_like(X)
     sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK)
     torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32))
+
+
+# stress test slice layout usages in reductions.
+@pytest.mark.parametrize("in_shape, perm, red_dims", [
+    ((4, 32, 32, 4, 2), [2, 1, 0, 3, 4], [3, 1, 0]),
+    ((8, 2, 32, 4, 16), [4, 0, 1, 3, 2], [0, 2, 0]),
+])
+def test_chained_reductions(in_shape, perm, red_dims, device):
+
+    @triton.jit
+    def kernel(In, Out, #
+               dim_0: tl.constexpr, dim_1: tl.constexpr, dim_2: tl.constexpr, dim_3: tl.constexpr, dim_4: tl.constexpr,
+               perm_0: tl.constexpr, perm_1: tl.constexpr, perm_2: tl.constexpr, perm_3: tl.constexpr,
+               perm_4: tl.constexpr, red_dim_0: tl.constexpr, red_dim_1: tl.constexpr, red_dim_2: tl.constexpr):
+        idx = tl.arange(0, dim_0 * dim_1 * dim_2 * dim_3 * dim_4)
+        idx = idx.reshape(dim_0, dim_1, dim_2, dim_3, dim_4)
+        vals = tl.load(In + idx)
+        vals = tl.permute(vals, [perm_0, perm_1, perm_2, perm_3, perm_4])
+        r = tl.sum(tl.sum(tl.sum(vals, red_dim_0), red_dim_1), red_dim_2)
+        st_idx = tl.arange(0, r.shape[0] * r.shape[1]).reshape(r.shape)
+        tl.store(Out + st_idx, r)
+
+    input = torch.randint(0, 1000, in_shape, device=device, dtype=torch.int32)
+    temp = torch.permute(input, perm).contiguous()
+    ref = torch.sum(torch.sum(torch.sum(temp, dim=red_dims[0]), dim=red_dims[1]), dim=red_dims[2])
+    result = torch.empty_like(ref)
+    kernel[(1, )](input, result, input.shape[0], input.shape[1], input.shape[2], input.shape[3], input.shape[4],
+                  perm[0], perm[1], perm[2], perm[3], perm[4], red_dims[0], red_dims[1], red_dims[2])
+
+    assert torch.all(ref == result)
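
The `red_dims` triples are applied sequentially, and each reduction drops a dimension, so later indices refer to the already-reduced intermediate rather than the original permuted shape (which is why `[0, 2, 0]` is a valid triple). A minimal host-side PyTorch sketch of that equivalence, using only the first parametrization above; the tensor names are illustrative and not part of this diff:

import torch

# Shapes taken from the first parametrization; names here are illustrative only.
x = torch.randint(0, 1000, (4, 32, 32, 4, 2), dtype=torch.int32)
t = x.permute(2, 1, 0, 3, 4).contiguous()       # shape (32, 32, 4, 4, 2)

# Each sum drops a dimension, so dim=1 below indexes the 4-D intermediate
# and dim=0 indexes the 3-D intermediate.
chained = t.sum(dim=3).sum(dim=1).sum(dim=0)    # shape (4, 2)

# Equivalent single reduction over the corresponding dims of the permuted tensor.
at_once = t.sum(dim=(3, 1, 0))                  # shape (4, 2)

assert torch.equal(chained, at_once)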