Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 45 additions & 22 deletions benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,31 @@
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
silu_mul_fp8_quant_deep_gemm,
silu_mul_fp8_quant_deep_gemm_cuda,
)
from vllm.model_executor.layers.fused_moe.old_batched_deep_gemm_moe import (
silu_mul_fp8_quant_deep_gemm as gold,
)
from vllm.platforms import current_platform


def benchmark(E, T, H, G=128, runs=50):
def benchmark(k, E, T, H, G=128, runs=100):
current_platform.seed_everything(42)
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()
tokens_per_expert = torch.randint(
T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
)

# Warmup
for _ in range(10):
silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
for _ in range(20):
k(y, tokens_per_expert, group_size=G)
torch.cuda.synchronize()

# Benchmark
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(runs):
silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
k(y, tokens_per_expert, group_size=G)
torch.cuda.synchronize()

avg_time = (time.perf_counter() - start) / runs * 1000
Expand All @@ -52,26 +55,46 @@ def benchmark(E, T, H, G=128, runs=50):


configs = [
(8, 32, 1024),
(16, 64, 2048),
(32, 128, 4096),
# DeepSeekV3 Configs
(256, 16, 7168),
(256, 32, 7168),
(256, 64, 7168),
(256, 128, 7168),
(256, 256, 7168),
(256, 512, 7168),
(256, 1024, 7168),
(8, 16, 7168),
(8, 32, 7168),
(8, 64, 7168),
(8, 128, 7168),
(8, 256, 7168),
(8, 512, 7168),
(8, 1024, 7168),
(9, 16, 7168),
(9, 32, 7168),
(9, 64, 7168),
(9, 128, 7168),
(9, 256, 7168),
(9, 512, 7168),
(9, 1024, 7168),
# (16, 64, 2048),
# (32, 128, 4096),
# (256, 16, 7168),
# (256, 32, 7168),
# (256, 64, 7168),
# (256, 128, 7168),
# (256, 256, 7168),
# (256, 512, 7168),
# (256, 1024, 7168),
]

print(f"GPU: {torch.cuda.get_device_name()}")

print(f"GPU: {torch.cuda.get_device_name()} CUDA Kernel")
print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
print("-" * 50)

for E, T, H in configs:
time_ms, gflops, gbps = benchmark(silu_mul_fp8_quant_deep_gemm_cuda, E, T, H)
print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")


print(f"GPU: {torch.cuda.get_device_name()} Baseline")
print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
print("-" * 50)

for E, T, H in configs:
try:
time_ms, gflops, gbps = benchmark(E, T, H)
print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
except Exception:
print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
time_ms, gflops, gbps = benchmark(gold, E, T, H)
print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
9 changes: 8 additions & 1 deletion csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& input_global_scale);
#endif
void silu_mul_fp8_quant_deep_gemm_cuda(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
int64_t group_size, double eps, double fp8_min, double fp8_max,
bool use_ue8m0);

void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

Expand Down Expand Up @@ -354,4 +361,4 @@ void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
#endif
#endif
Loading
Loading