|
1 | | -import torch |
2 | | -import pytest |
3 | | -import math |
4 | | -from triton_kernels.testing import assert_equal |
5 | | -from triton_kernels.tensor_details.layout import BlackwellMXScaleLayout, HopperMXScaleLayout, HopperMXValueLayout |
6 | | - |
7 | | - |
8 | | -@pytest.mark.parametrize( |
9 | | - "shape", |
10 | | - [ |
11 | | - (3, 4096, 1024), |
12 | | - (10, 254, 60), |
13 | | - (1, 320, 160), |
14 | | - (2, 16, 512), |
15 | | - (3, 2, 36), |
16 | | - ], |
17 | | -) |
18 | | -def test_mxfp_swizzle(shape: tuple[int, ...]): |
19 | | - """ |
20 | | - Test that unswizzle is the inverse of swizzle, after removing padding. |
21 | | - """ |
22 | | - x = torch.randn(shape, device="cuda") |
23 | | - layout = BlackwellMXScaleLayout(shape) |
24 | | - assert_equal(x, layout.unswizzle_data(layout.swizzle_data(x))) |
25 | | - |
26 | | - |
27 | | -@pytest.mark.parametrize("shape", [(16, 32), (16, 64), (32, 32), (32, 64), (64, 128), (128, 128)]) |
28 | | -@pytest.mark.parametrize("trans", [False, True]) |
29 | | -@pytest.mark.parametrize("op_idx", [0, 1]) |
30 | | -@pytest.mark.parametrize("mma_version", [2, 3]) |
31 | | -def test_swizzle_mxfp4_value(shape, trans, op_idx, mma_version): |
32 | | - x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda") |
33 | | - if trans: |
34 | | - x = x.mT |
35 | | - k_dim = 1 - op_idx |
36 | | - if x.shape[k_dim] < 32: |
37 | | - pytest.skip("Not enough elements along K") |
38 | | - layout = HopperMXValueLayout(x.shape, op_idx, mma_version) |
39 | | - res = layout.unswizzle_data(layout.swizzle_data(x)) |
40 | | - assert (res == x).all() |
41 | | - |
42 | | - |
43 | | -@pytest.mark.parametrize("num_warps", [4, 8]) |
44 | | -@pytest.mark.parametrize("shape", [(256, 64), (256, 128), (256, 256)]) |
45 | | -def test_swizzle_mxfp4_scale(shape, num_warps): |
46 | | - x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda") |
47 | | - layout = HopperMXScaleLayout(x.shape, num_warps=num_warps) |
48 | | - res = layout.unswizzle_data(layout.swizzle_data(x)) |
49 | | - assert (res[:shape[0], :shape[1]] == x).all() |
50 | | - |
51 | | - |
52 | | -def test_unswizzle_mxfp4_value_golden_value(): |
53 | | - shape = (16, 32) |
54 | | - x = torch.arange(math.prod(shape)).view(shape).to(torch.uint8) |
55 | | - layout = HopperMXValueLayout(x.shape, op_idx=1, mma_version=3) |
56 | | - res = layout.swizzle_data(x) |
57 | | - # Thread 0 |
58 | | - assert res[0, 0:16].tolist() == [0, 0, 4, 4, 8, 8, 12, 12, 16, 16, 20, 20, 24, 24, 28, 28] |
59 | | - # Thread 1 |
60 | | - assert res[0, 16:32].tolist() == [1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21, 25, 25, 29, 29] |
| 1 | +# TODO: add tests for non-layout parts of tensor class |
0 commit comments