|
| 1 | +import torch |
| 2 | +import time |
| 3 | +import pytest |
| 4 | +import numpy as np |
| 5 | +from lightllm.common.fused_moe.softmax_topk import softmax_topk |
| 6 | +from lightllm.utils.log_utils import init_logger |
| 7 | + |
# Module-level logger for this test module (currently unused in the visible
# code — presumably kept for parity with other lightllm test files; confirm).
logger = init_logger(__name__)
| 9 | + |
| 10 | + |
def benchmark(M, N, K, renorm, runs):
    """Benchmark and cross-check three top-k softmax implementations.

    Runs the SGL CUDA kernel, the lightllm Triton kernel, and a native
    PyTorch softmax+topk over the same random gating matrix, times each
    with CUDA events, and asserts that all three agree on the selected
    expert ids and (within 1e-3) on the selected weights.

    Args:
        M: number of rows (tokens) in the gating matrix.
        N: number of columns (experts) in the gating matrix.
        K: how many top experts to select per row.
        renorm: if True, renormalize the selected top-K weights so each
            row sums to 1 (the Triton kernel takes this as a flag; the
            other two apply an explicit in-place division).
        runs: number of timed iterations per implementation.

    Returns:
        dict with average per-iteration times in ms ("time_sgl",
        "time_triton", "time_torch"), id mismatch counts, max absolute
        weight errors between each pair, and the raw output tensors.

    Raises:
        AssertionError: if any pair of implementations disagrees on ids
            or differs in weights by 1e-3 or more.
    """
    import sgl_kernel as sgl_ops

    gating = torch.randn(M, N, device="cuda", dtype=torch.float32)
    torch.cuda.synchronize()

    # 1. SGL kernel
    sgl_vals = torch.empty((M, K), dtype=torch.float32, device="cuda")
    sgl_ids = torch.empty((M, K), dtype=torch.int32, device="cuda")
    # Scratch output buffer required by the kernel's signature. Allocated
    # once, outside the timed loop, so the SGL timing measures the kernel
    # rather than per-iteration allocator overhead (the other candidates
    # pay no such allocation cost).
    sgl_scratch = torch.empty_like(sgl_ids)
    # Warm-up
    sgl_ops.topk_softmax(sgl_vals, sgl_ids, sgl_scratch, gating)
    torch.cuda.synchronize()
    start = torch.cuda.Event(True)
    end = torch.cuda.Event(True)
    start.record()
    for _ in range(runs):
        sgl_ops.topk_softmax(sgl_vals, sgl_ids, sgl_scratch, gating)
        if renorm:
            # clamp_min guards against division by zero on degenerate rows.
            # Re-applying this across iterations is a no-op after the first
            # division (row sums are already 1), so the final values are valid.
            sgl_vals.div_(sgl_vals.sum(-1, keepdim=True).clamp_min(1e-8))
    end.record()
    torch.cuda.synchronize()
    t_sgl = start.elapsed_time(end) / runs

    # 2. Triton kernel
    t0 = torch.cuda.Event(True)
    t1 = torch.cuda.Event(True)
    # Warm-up (also triggers any Triton JIT compilation outside the timing)
    softmax_topk(gating, K)
    torch.cuda.synchronize()
    t0.record()
    for _ in range(runs):
        triton_vals, triton_ids = softmax_topk(gating, K, renorm)
    t1.record()
    torch.cuda.synchronize()
    t_triton = t0.elapsed_time(t1) / runs

    # 3. Native PyTorch
    # Warm-up
    warm_probs = torch.softmax(gating, dim=-1)
    torch.topk(warm_probs, K, dim=-1)
    torch.cuda.synchronize()

    start, end = torch.cuda.Event(True), torch.cuda.Event(True)
    start.record()
    for _ in range(runs):
        probs = torch.softmax(gating, dim=-1)
        torch_vals, torch_ids = torch.topk(probs, K, dim=-1)
        if renorm:
            torch_vals.div_(torch_vals.sum(-1, keepdim=True).clamp_min(1e-8))
    end.record()
    torch.cuda.synchronize()
    t_torch = start.elapsed_time(end) / runs

    # Compare indices and weights
    # Count mismatches of ordered indices
    diff_sgl_triton_ids = (sgl_ids != triton_ids).sum().item()
    diff_torch_triton_ids = (torch_ids != triton_ids).sum().item()
    # Max absolute difference of weights aligned by position.
    # NOTE: max_err_triton_torch_sgl actually compares SGL vs Torch; the
    # name is kept because it is also a key of the returned dict.
    max_err_triton_torch = (triton_vals - torch_vals).abs().max().item()
    max_err_triton_torch_sgl = (sgl_vals - torch_vals).abs().max().item()
    max_err_triton_sgl = (triton_vals - sgl_vals).abs().max().item()

    assert diff_sgl_triton_ids == 0, f"Mismatch SGL vs Triton ids: {diff_sgl_triton_ids}"
    assert diff_torch_triton_ids == 0, f"Mismatch Torch vs Triton ids: {diff_torch_triton_ids}"
    assert max_err_triton_torch < 1e-3, f"Max err Triton vs Torch: {max_err_triton_torch}"
    # Fixed: these two messages were swapped relative to the variables
    # they report (max_err_triton_torch_sgl is SGL-vs-Torch,
    # max_err_triton_sgl is Triton-vs-SGL).
    assert max_err_triton_torch_sgl < 1e-3, f"Max err SGL vs Torch: {max_err_triton_torch_sgl}"
    assert max_err_triton_sgl < 1e-3, f"Max err Triton vs SGL: {max_err_triton_sgl}"

    results = {
        "time_sgl": t_sgl,
        "time_triton": t_triton,
        "time_torch": t_torch,
        "mismatch_sgl_triton_ids": diff_sgl_triton_ids,
        "mismatch_torch_triton_ids": diff_torch_triton_ids,
        "max_err_triton_torch": max_err_triton_torch,
        "max_err_triton_sgl": max_err_triton_sgl,
        "max_err_triton_torch_sgl": max_err_triton_torch_sgl,
        "sgl_ids": sgl_ids,
        "triton_ids": triton_ids,
        "torch_ids": torch_ids,
        "sgl_vals": sgl_vals,
        "triton_vals": triton_vals,
        "torch_vals": torch_vals,
    }
    return results
| 96 | + |
| 97 | + |
def test_softmax_topk():
    """Cross-check Triton softmax_topk against SGL and native PyTorch.

    One large timed comparison (with printed timings/error stats), then
    three short renormalized runs over different expert counts, including
    a non-power-of-two width (127).
    """
    num_tokens, num_experts, top_k = 8192, 1024, 8
    res = benchmark(num_tokens, num_experts, top_k, False, 1000)

    # Timing summary, one line per implementation.
    for label, key in (("SGL", "time_sgl"), ("Triton", "time_triton"), ("PyTorch", "time_torch")):
        print(f"{label} time: {res[key]:.6f}ms")

    # Agreement stats between each pair of implementations.
    for label, key in (
        ("Mismatch SGL vs Triton ids:", "mismatch_sgl_triton_ids"),
        ("Mismatch Torch vs Triton ids:", "mismatch_torch_triton_ids"),
        ("Max err Triton vs Torch :", "max_err_triton_torch"),
        ("Max err Triton vs SGL :", "max_err_triton_sgl"),
        ("Max err Torch vs SGL :", "max_err_triton_torch_sgl"),
    ):
        print(label, res[key])

    # Renormalized smoke runs, including a non-power-of-two expert count.
    for n_experts, k in ((num_experts, top_k), (256, 5), (127, 5)):
        benchmark(num_tokens, n_experts, k, True, 10)
| 112 | + |
| 113 | + |
# Allow running this file directly; pytest.main() with no args collects
# tests from the current directory using default pytest configuration.
if __name__ == "__main__":
    pytest.main()
0 commit comments