import torch
import triton
import triton.language as tl

import sgl_kernel as sgl_ops  # SGLang kernel library; provides the fused topk_softmax


@triton.jit
def softmax_topk_kernel(
    topk_weights_ptr,
    topk_indices_ptr,
    gating_output_ptr,
    input_row_stride,
    output_weights_row_stride,
    output_indices_row_stride,
    n_rows,  # total rows; the launch grid is sized to exactly n_rows
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    top_k: tl.constexpr,
):
    # One program per row: fused softmax + top-k via iterative "peeling"
    # (take the argmax, emit its softmax probability, mask it out, repeat).
    row_idx = tl.program_id(0)

    row_input_ptr = gating_output_ptr + row_idx * input_row_stride
    row_weights_ptr = topk_weights_ptr + row_idx * output_weights_row_stride
    row_indices_ptr = topk_indices_ptr + row_idx * output_indices_row_stride

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    # Out-of-range lanes become -inf: they never win the argmax and
    # contribute exp(-inf) = 0 to the softmax denominator.
    values = tl.load(row_input_ptr + offsets, mask=mask, other=-float("inf"))

    # Numerically stable softmax: shift by the row max before exponentiating.
    current_max = tl.max(values, axis=0)
    values = values - current_max
    denom = tl.sum(tl.exp(values), axis=0)

    for i in range(top_k):
        # Largest remaining logit and its position.
        logit = tl.max(values, axis=0)
        idx = tl.argmax(values, axis=0)

        prob = tl.exp(logit) / denom

        # Scalar stores: one (weight, index) pair per iteration.
        tl.store(row_weights_ptr + i, prob)
        tl.store(row_indices_ptr + i, idx)

        # Knock out the selected expert so the next pass picks the runner-up.
        values = tl.where(offsets == idx, -float("inf"), values)

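# Note: the kernel loads a whole row in a single block (BLOCK_SIZE is the
# next power of two >= n_cols), so it assumes the expert count fits in one
# block; much larger n_cols would need a multi-pass reduction instead.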

def softmax_topk(gating_output: torch.Tensor, topk: int):
    assert gating_output.dim() == 2, "gating_output must be 2-D (tokens, experts)."
    num_tokens, num_experts = gating_output.shape
    device = gating_output.device

    # The kernel computes in fp32; upcast lower-precision gating logits.
    if gating_output.dtype != torch.float32:
        gating_output = gating_output.to(torch.float32)

    topk_vals = torch.empty((num_tokens, topk), dtype=torch.float32, device=device)
    topk_idxs = torch.empty((num_tokens, topk), dtype=torch.int32, device=device)

    BLOCK_SIZE = triton.next_power_of_2(num_experts)

    # One program instance per token row.
    grid = (num_tokens,)
    softmax_topk_kernel[grid](
        topk_vals,
        topk_idxs,
        gating_output,
        gating_output.stride(0),
        topk_vals.stride(0),
        topk_idxs.stride(0),
        num_tokens,
        num_experts,
        BLOCK_SIZE=BLOCK_SIZE,
        top_k=topk,
        num_warps=8,
    )
    return topk_vals, topk_idxs

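# Many MoE routers renormalize the selected weights so each row sums to 1
# (the kernel above returns raw softmax probabilities). If that behaviour is
# wanted, a minimal post-processing sketch:
def renormalize_topk(topk_vals: torch.Tensor) -> torch.Tensor:
    # Divide each row of top-k probabilities by its row sum.
    return topk_vals / topk_vals.sum(dim=-1, keepdim=True)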


# Benchmark: SGL fused kernel vs. Triton kernel vs. native PyTorch.
def benchmark(M, N, K):
    gating = torch.randn(M, N, device="cuda", dtype=torch.float32)
    torch.cuda.synchronize()

    # 1. SGL kernel
    sgl_vals = torch.empty((M, K), dtype=torch.float32, device="cuda")
    sgl_ids = torch.empty((M, K), dtype=torch.int32, device="cuda")
    # Warm-up (the third argument is the token_expert_indices scratch buffer)
    sgl_ops.topk_softmax(sgl_vals, sgl_ids, torch.empty_like(sgl_ids), gating)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    sgl_ops.topk_softmax(sgl_vals, sgl_ids, torch.empty_like(sgl_ids), gating)
    end.record()
    torch.cuda.synchronize()
    t_sgl = start.elapsed_time(end) / 1000.0  # elapsed_time returns ms

    # 2. Triton kernel
    t0 = torch.cuda.Event(enable_timing=True)
    t1 = torch.cuda.Event(enable_timing=True)
    # Warm-up (also triggers JIT compilation)
    softmax_topk(gating, K)
    torch.cuda.synchronize()
    t0.record()
    triton_vals, triton_ids = softmax_topk(gating, K)
    t1.record()
    torch.cuda.synchronize()
    t_triton = t0.elapsed_time(t1) / 1000.0

    # 3. Native PyTorch
    # Warm-up
    torch.topk(torch.softmax(gating, dim=-1), K, dim=-1)
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    probs = torch.softmax(gating, dim=-1)
    torch_vals, torch_ids = torch.topk(probs, K, dim=-1)
    end.record()
    torch.cuda.synchronize()
    t_torch = start.elapsed_time(end) / 1000.0

    # Compare indices and weights.
    # Count positional mismatches of the ordered indices (ties between equal
    # logits may legitimately be broken differently across implementations).
    diff_sgl_triton_ids = (sgl_ids != triton_ids).sum().item()
    diff_torch_triton_ids = (torch_ids != triton_ids).sum().item()
    # Max absolute difference of weights aligned by position.
    max_err_triton_torch = (triton_vals - torch_vals).abs().max().item()
    max_err_sgl_torch = (sgl_vals - torch_vals).abs().max().item()
    max_err_triton_sgl = (triton_vals - sgl_vals).abs().max().item()

    results = {
        "time_sgl": t_sgl,
        "time_triton": t_triton,
        "time_torch": t_torch,
        "mismatch_sgl_triton_ids": diff_sgl_triton_ids,
        "mismatch_torch_triton_ids": diff_torch_triton_ids,
        "max_err_triton_torch": max_err_triton_torch,
        "max_err_triton_sgl": max_err_triton_sgl,
        "max_err_sgl_torch": max_err_sgl_torch,
        "sgl_ids": sgl_ids,
        "triton_ids": triton_ids,
        "torch_ids": torch_ids,
        "sgl_vals": sgl_vals,
        "triton_vals": triton_vals,
        "torch_vals": torch_vals,
    }
    return results
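

# A more robust alternative to the one-shot event timings above: Triton's
# built-in benchmarking helper, which warms up and averages many runs.
# Sketch (the statistic returned depends on the installed Triton version):
def bench_triton_ms(gating: torch.Tensor, topk: int) -> float:
    # Runtime of the Triton kernel in milliseconds.
    return triton.testing.do_bench(lambda: softmax_topk(gating, topk))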


if __name__ == "__main__":
    # Example: 8192 tokens, 1024 experts, Top-4
    M, N, K = 8192, 1024, 4
    res = benchmark(M, N, K)
    print(f"SGL time:     {res['time_sgl']:.6f}s")
    print(f"Triton time:  {res['time_triton']:.6f}s")
    print(f"PyTorch time: {res['time_torch']:.6f}s")
    print("Mismatch SGL vs Triton ids:  ", res["mismatch_sgl_triton_ids"])
    print("Mismatch Torch vs Triton ids:", res["mismatch_torch_triton_ids"])
    print("Max err Triton vs Torch:", res["max_err_triton_torch"])
    print("Max err Triton vs SGL:  ", res["max_err_triton_sgl"])
    print("Max err SGL vs Torch:   ", res["max_err_sgl_torch"])