|
| 1 | +import pytest |
| 2 | +import torch |
| 3 | +from triton_bench.routing import routing, routing_torch |
| 4 | +from triton_bench.testing import assert_close |
| 5 | +from triton_bench.matmul_ogs_details.metadata import compute_metadata |
| 6 | +from triton_bench.testing import assert_equal |
| 7 | + |
| 8 | + |
def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"):
    """Create an (n_tokens, n_expts_tot) tensor of pseudo-random router logits.

    Each row is built from a random permutation of [0, n_expts_tot), so all
    logits within a row are distinct bit patterns: the reference and triton
    implementations do not tie-break equal experts the same way, and distinct
    values keep the top-k selection unambiguous.

    Args:
        n_tokens: number of rows (tokens) to generate.
        n_expts_tot: number of columns (experts); also the permutation length.
        dtype: 2-byte dtype that the int16 bit patterns are reinterpreted as.
        device: target device for the returned tensor (defaults to "cuda",
            matching the previously hard-coded behavior).

    Returns:
        Tensor of shape (n_tokens, n_expts_tot) with dtype ``dtype``.
    """
    randbits = [torch.randperm(n_expts_tot) for _ in range(n_tokens)]
    # Build int16 bit patterns (the 16384 offset plus a per-row shift keeps the
    # reinterpreted values in a well-behaved range), view them as `dtype`, and
    # alternate the sign per row.
    x = [(-1)**i * ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(dtype))
         for i, bits in enumerate(randbits)]
    return torch.stack(x).to(device=device)
| 15 | + |
| 16 | + |
def ref_expt_data(routing_data, n_gates, block_m):
    """Reference (pure-torch) construction of the ragged-matmul metadata buffer.

    Layout of the returned int32 tensor (segment lengths in brackets):
        [n_expts_tot]   per-expert token histogram
        [1]             zero pad (base element for the token prefix sum)
        [n_expts_tot]   inclusive prefix sum of the histogram
        [1]             zero pad (base element for the block prefix sum)
        [n_expts_tot]   inclusive prefix sum of per-expert matmul-block counts
        [grid_m]        packed (block_idx << 16 | expert_idx) entries, -1 padded
    """
    hist = routing_data.expt_hist
    n_expts = routing_data.n_expts_tot
    blocks_per_expt = (hist + block_m - 1) // block_m  # matmul blocks needed
    token_prefix = torch.cumsum(hist, dim=0)           # prefix sum of tokens
    block_prefix = torch.cumsum(blocks_per_expt, dim=0)  # prefix sum of blocks
    # Upper bound on the number of matmul blocks launched, assuming the worst
    # split: every expert holds one token except a single expert with the rest.
    if n_gates <= n_expts:
        grid_m = n_gates
    else:
        # equivalent to ceil_div(n_gates - n_expts + 1, block_m) + n_expts - 1
        grid_m = n_expts - 1 - ((n_expts - n_gates - 1) // block_m)
    # Map each launched block to its expert and its block index within that
    # expert; unused tail slots stay -1.
    block_map = -torch.ones(grid_m, dtype=torch.int32)
    next_slot = 0
    for expt, n_blks in enumerate(blocks_per_expt.tolist()):
        for blk in range(n_blks):
            block_map[next_slot] = (blk << 16) + expt
            next_slot += 1

    out = torch.zeros(n_expts * 3 + 2 + grid_m, dtype=torch.int32, device=hist.device)
    out[:n_expts] = hist
    out[n_expts + 1:2 * n_expts + 1] = token_prefix
    out[2 * n_expts + 2:3 * n_expts + 2] = block_prefix
    out[3 * n_expts + 2:] = block_map
    return out
| 44 | + |
| 45 | + |
@pytest.mark.parametrize("n_tokens", [371, 255, 256, 8192, 1023, 1024])
@pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 4)])
@pytest.mark.parametrize("block_m", [64, 128])
def test_op(n_tokens, n_expts_tot, n_expts_act, block_m):
    """Check the triton routing kernel against the pure-torch reference.

    Runs both implementations on identical logits and compares routing data,
    metadata buffers, and gather/scatter index tensors.
    """
    torch.manual_seed(2)
    tri_logits = init_data(n_tokens, n_expts_tot).detach()
    ref_logits = tri_logits.clone()
    ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act)
    tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act)
    ref_metadata = ref_expt_data(ref_routing_data, n_tokens * n_expts_act, block_m)
    tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m).buffer

    assert_close(ref_routing_data.gate_scal, tri_routing_data.gate_scal, 2e-2, 4e-3)
    assert_equal(ref_routing_data.expt_hist, tri_routing_data.expt_hist)
    assert_equal(ref_metadata, tri_metadata)
    # Fixed: these previously compared ref_routing_data against itself, which
    # was vacuously true and tested nothing.
    assert ref_routing_data.n_expts_tot == tri_routing_data.n_expts_tot
    assert ref_routing_data.n_expts_act == tri_routing_data.n_expts_act

    def _assert_indx_equal(ref, tri):
        # The triton buffers may be over-allocated: the leading entries must
        # match the reference and the tail must be -1 padding.
        assert_equal(ref, tri[:len(ref)])
        assert torch.all(tri[len(ref):] == -1)

    _assert_indx_equal(ref_gather.src_indx, tri_gather.src_indx)
    _assert_indx_equal(ref_gather.dst_indx, tri_gather.dst_indx)
    _assert_indx_equal(ref_scatter.src_indx, tri_scatter.src_indx)
    _assert_indx_equal(ref_scatter.dst_indx, tri_scatter.dst_indx)
| 72 | + |
| 73 | + |
def bench_routing():
    """Profile 100 iterations of routing + metadata computation with proton."""
    import triton.profiler as proton
    n_tokens, block_m = 2048, 128
    n_expts_tot, n_expts_act = 128, 4
    logits = init_data(n_tokens, n_expts_tot)
    proton.start("routing")
    proton.activate()
    for _ in range(100):
        routing_data, gather_indx, scatter_indx = routing(logits, n_expts_act)
        metadata = compute_metadata(routing_data, n_tokens * n_expts_act, block_m)
    proton.finalize()
| 86 | + |
| 87 | + |
# Running this file directly executes the profiling benchmark; the correctness
# tests above are collected by pytest instead.
if __name__ == "__main__":
    bench_routing()
0 commit comments