Draft
Changes from 17 commits
93 changes: 75 additions & 18 deletions benchmarks/bench_cutlass_fused_moe.py
@@ -13,6 +13,10 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import os
import sys
import time
Comment on lines +17 to +19
medium

Consider removing the unused sys and time imports to maintain clean and readable code.

Suggested change
import os
import sys
import time
import json
import os
import torch


import torch
from torch.nn import functional as F
@@ -21,6 +25,7 @@
import flashinfer
import flashinfer.fused_moe as fused_moe
from flashinfer import fp4_quantize
from flashinfer.testing.utils import bench_kineto

BATCH_SIZES = [
1,
@@ -35,7 +40,9 @@
96,
128,
256,
384, # NOTE ADD
medium

Remove this temporary comment, along with the blocks of commented-out code (e.g., lines 84-97, 215-227), to improve code clarity.

Suggested change
384, # NOTE ADD
384,

512,
768, # NOTE ADD
1024,
1536,
2048,
@@ -53,18 +60,40 @@
FP8_DTYPE = torch.float8_e4m3fn

test_configs = [
{
"hidden_size": 7168,
"num_experts": 256,
"top_k": 8,
"intermediate_size": 256,
},
{
"hidden_size": 7168,
"num_experts": 32,
"top_k": 8,
"intermediate_size": 2048,
},
# NOTE MODIFIED ADD
*[
{
"hidden_size": 7168,
"num_experts": num_experts,
"top_k": 8,
"intermediate_size": 2048,
}
for num_experts in [
288 // 1,
288 // 2,
288 // 4,
288 // 8,
288 // 16,
288 // 32,
# TODO support
# 288 // 48,
# 288 // 72,
]
],

# --- old ---
# {
# "hidden_size": 7168,
# "num_experts": 256,
# "top_k": 8,
# "intermediate_size": 256,
# },
# {
# "hidden_size": 7168,
# "num_experts": 32,
# "top_k": 8,
# "intermediate_size": 2048,
# },
]


@@ -182,7 +211,22 @@ def bench_cutlass_fused_moe(
input_sf=input_sf,
output=flash_output,
)
ms = do_bench(
# NOTE MODIFIED
# ms = do_bench(
# lambda: fused_moe.cutlass_fused_moe(
# hidden_states,
# selected_experts.to(torch.int),
# routing_weights,
# w1_q.contiguous().view(torch.long),
# w2_q.contiguous().view(torch.long),
# otype,
# quant_scales=quant_scales,
# input_sf=input_sf,
# output=flash_output,
# )
# )
trace_dir = os.environ.get("BENCH_KINETO_TRACE_DIR")
[time_gemm1, time_gemm2] = bench_kineto(
lambda: fused_moe.cutlass_fused_moe(
hidden_states,
selected_experts.to(torch.int),
Expand All @@ -193,12 +237,25 @@ def bench_cutlass_fused_moe(
quant_scales=quant_scales,
input_sf=input_sf,
output=flash_output,
)
),
kernel_names="cutlass13device_kernelINS_4gemm6kernel",
num_kernels_per_period=2,
trace_path=f"{trace_dir}/{time.time()}.trace.json.gz" if trace_dir else None,
)
print(
f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}"
)
print(f"execution time: {ms}ms")

# NOTE MODIFIED
print(f"MAIN_OUTPUT=" + json.dumps(dict(
batch_size=batch_size,
num_experts=num_experts,
top_k=top_k,
intermediate_size=intermediate_size,
time_gemm1_us=time_gemm1 * 1e6,
time_gemm2_us=time_gemm2 * 1e6,
)))
Comment on lines +247 to +254
medium

Use a single f-string for the whole line to improve readability.

    print(f"MAIN_OUTPUT={\"batch_size\":{batch_size},\"num_experts\":{num_experts},\"top_k\":{top_k},\"intermediate_size\":{intermediate_size},\"time_gemm1_us\":{time_gemm1 * 1e6},\"time_gemm2_us\":{time_gemm2 * 1e6}}")

# print(
# f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}"
# )
# print(f"execution time: {ms}ms")


if __name__ == "__main__":
124 changes: 122 additions & 2 deletions flashinfer/testing/utils.py
@@ -13,8 +13,13 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Tuple
import json
import os
import sys
import tempfile
from pathlib import Path
from typing import Tuple, Union, Optional
import torch.distributed as dist

import torch
from einops import rearrange, reduce, repeat
@@ -206,3 +211,118 @@ def dequantize_fp8(x, x_scale, scale_major_mode)
x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1")
out = rearrange(x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 t2)")
return out


class empty_suppress:
def __enter__(self):
return self

def __exit__(self, *_):
pass


class suppress_stdout_stderr:
def __enter__(self):
self.outnull_file = open(os.devnull, 'w')
self.errnull_file = open(os.devnull, 'w')

self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()

self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())

self.old_stdout = sys.stdout
self.old_stderr = sys.stderr

os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self

def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr

os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)

self.outnull_file.close()
self.errnull_file.close()
Comment on lines +224 to +256
high

The current implementation of suppress_stdout_stderr is not exception-safe, which can lead to resource leaks. Use a try...finally block or contextlib.ExitStack to ensure resources are always cleaned up correctly, even in case of errors.

Suggested change
class suppress_stdout_stderr:
def __enter__(self):
self.outnull_file = open(os.devnull, 'w')
self.errnull_file = open(os.devnull, 'w')
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)
self.outnull_file.close()
self.errnull_file.close()
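
A minimal sketch of the kind of cleanup this comment asks for (an illustration, not the reviewer's exact patch): the same fd-level redirection, rewritten as a generator-based context manager so a try/finally restores the saved descriptors and sys streams even if the profiled code raises partway through. Call sites such as `with suppress():` in bench_kineto keep working, since calling the function still returns a context manager.

```python
import contextlib
import os
import sys


@contextlib.contextmanager
def suppress_stdout_stderr():
    """Redirect stdout/stderr (and their underlying fds) to /dev/null."""
    with open(os.devnull, "w") as devnull:
        stdout_fd, stderr_fd = sys.stdout.fileno(), sys.stderr.fileno()
        # Duplicate the original fds so they can be restored later.
        saved_stdout_fd, saved_stderr_fd = os.dup(stdout_fd), os.dup(stderr_fd)
        old_stdout, old_stderr = sys.stdout, sys.stderr
        try:
            os.dup2(devnull.fileno(), stdout_fd)
            os.dup2(devnull.fileno(), stderr_fd)
            sys.stdout = sys.stderr = devnull
            yield
        finally:
            # Runs even if the body raises: restore the Python-level streams,
            # then the original fds, then drop the duplicated descriptors.
            sys.stdout, sys.stderr = old_stdout, old_stderr
            os.dup2(saved_stdout_fd, stdout_fd)
            os.dup2(saved_stderr_fd, stderr_fd)
            os.close(saved_stdout_fd)
            os.close(saved_stderr_fd)
```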



# copy and modified from DeepGEMM and DeepEP
def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: Optional[str] = None, barrier_comm_profiling: bool = False, flush_l2: bool = True,
num_kernels_per_period: int = 1):
# Conflict with Nsight Systems
using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0))

# By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle
flush_l2_size = int(8e9 // 4)

# For some auto-tuning kernels with prints
fn()

# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
with suppress():
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
with profiler:
for i in range(2):
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
if barrier_comm_profiling:
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
lhs @ rhs
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
for _ in range(num_tests):
if flush_l2:
torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
fn()

if not using_nsys:
profiler.step()

# Return 1 if using Nsight Systems
if using_nsys:
return 1

# Parse the profiling table
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
is_tuple = isinstance(kernel_names, tuple)
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
# print(f"prof_lines=\n" + "\n".join(prof_lines))
medium

Remove the commented-out print statement to improve clarity.


# Save chrome traces
if trace_path is not None:
print(f"export_chrome_trace to {trace_path=}")
profiler.export_chrome_trace(trace_path)

kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])

kernel_durations = [None] * len(kernel_names)

# Expand the kernels by periods
with tempfile.NamedTemporaryFile(suffix='.json') as tmp:
profiler.export_chrome_trace(tmp.name)
profile_data = json.loads(Path(tmp.name).read_text())

for i, kernel_name in enumerate(kernel_names):
events = [event for event in profile_data['traceEvents'] if kernel_name in event['name']]
events = sorted(events, key=lambda event: event['ts'])
durations = [event['dur'] / 1e6 for event in events]
assert len(durations) % num_kernels_per_period == 0
num_kernel_patterns = len(durations) // num_kernels_per_period
kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns
for j in range(num_kernels_per_period)]

# Return execution durations
return kernel_durations if is_tuple else kernel_durations[0]
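
Not part of the diff, just a usage sketch of the new helper on a plain matmul (assumes a CUDA GPU; the "gemm" substring is an assumption — substitute whatever kernel name appears in the profiler table on your setup). With a single kernel name and the default num_kernels_per_period=1, the helper returns a one-element list of average durations in seconds.

```python
import torch
from flashinfer.testing.utils import bench_kineto

x = torch.randn(4096, 4096, device="cuda")
y = torch.randn(4096, 4096, device="cuda")

# kernel_names is matched as a substring of the profiled kernel names.
[gemm_time_s] = bench_kineto(
    lambda: x @ y,
    kernel_names="gemm",   # assumed substring; check the profiler output on your GPU
    num_tests=10,
    flush_l2=False,        # skip the 8 GB L2-flush memset for this small sketch
)
print(f"matmul kernel: {gemm_time_s * 1e6:.1f} us")
```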
3 changes: 2 additions & 1 deletion setup.py
@@ -62,7 +62,8 @@ def generate_build_meta(aot_build_meta: dict) -> None:
"einops",
"nvidia-nvshmem-cu12",
"nvidia-cudnn-cu12",
"nvidia-cudnn-frontend",
# NOTE MODIFIED rm
# "nvidia-cudnn-frontend",
Comment on lines +65 to +66
medium

Remove this temporary comment and the commented-out line.

]
generate_build_meta({})
