-
Notifications
You must be signed in to change notification settings - Fork 471
[WIP] Add benchmark scripts to different moe gemms #1315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 17 commits
a5390df
94c6d27
64b0032
cfcc7ce
0e63dbe
72ec2c4
8caafde
5cbc7d0
ab252c7
8629008
4624913
9de0387
c89cf7c
adf2cb1
57eda4a
8b606b8
7b78322
8c53c59
c4c8a07
a19273c
a129862
3d37611
62eaa46
dbecb4a
91863e4
3baee7e
57a18b6
806cf11
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,10 @@ | |
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import json | ||
import os | ||
import sys | ||
import time | ||
|
||
import torch | ||
from torch.nn import functional as F | ||
|
@@ -21,6 +25,7 @@ | |
import flashinfer | ||
import flashinfer.fused_moe as fused_moe | ||
from flashinfer import fp4_quantize | ||
from flashinfer.testing.utils import bench_kineto | ||
|
||
BATCH_SIZES = [ | ||
1, | ||
|
@@ -35,7 +40,9 @@ | |
96, | ||
128, | ||
256, | ||
384, # NOTE ADD | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
512, | ||
768, # NOTE ADD | ||
1024, | ||
1536, | ||
2048, | ||
|
@@ -53,18 +60,40 @@ | |
FP8_DTYPE = torch.float8_e4m3fn | ||
|
||
test_configs = [ | ||
{ | ||
"hidden_size": 7168, | ||
"num_experts": 256, | ||
"top_k": 8, | ||
"intermediate_size": 256, | ||
}, | ||
{ | ||
"hidden_size": 7168, | ||
"num_experts": 32, | ||
"top_k": 8, | ||
"intermediate_size": 2048, | ||
}, | ||
# NOTE MODIFIED ADD | ||
*[ | ||
{ | ||
"hidden_size": 7168, | ||
"num_experts": num_experts, | ||
"top_k": 8, | ||
"intermediate_size": 2048, | ||
} | ||
for num_experts in [ | ||
288 // 1, | ||
288 // 2, | ||
288 // 4, | ||
288 // 8, | ||
288 // 16, | ||
288 // 32, | ||
# TODO support | ||
# 288 // 48, | ||
# 288 // 72, | ||
] | ||
], | ||
|
||
# --- old --- | ||
# { | ||
# "hidden_size": 7168, | ||
# "num_experts": 256, | ||
# "top_k": 8, | ||
# "intermediate_size": 256, | ||
# }, | ||
# { | ||
# "hidden_size": 7168, | ||
# "num_experts": 32, | ||
# "top_k": 8, | ||
# "intermediate_size": 2048, | ||
# }, | ||
] | ||
|
||
|
||
|
@@ -182,7 +211,22 @@ def bench_cutlass_fused_moe( | |
input_sf=input_sf, | ||
output=flash_output, | ||
) | ||
ms = do_bench( | ||
# NOTE MODIFIED | ||
# ms = do_bench( | ||
# lambda: fused_moe.cutlass_fused_moe( | ||
# hidden_states, | ||
# selected_experts.to(torch.int), | ||
# routing_weights, | ||
# w1_q.contiguous().view(torch.long), | ||
# w2_q.contiguous().view(torch.long), | ||
# otype, | ||
# quant_scales=quant_scales, | ||
# input_sf=input_sf, | ||
# output=flash_output, | ||
# ) | ||
# ) | ||
trace_dir = os.environ.get("BENCH_KINETO_TRACE_DIR") | ||
[time_gemm1, time_gemm2] = bench_kineto( | ||
lambda: fused_moe.cutlass_fused_moe( | ||
hidden_states, | ||
selected_experts.to(torch.int), | ||
|
@@ -193,12 +237,25 @@ def bench_cutlass_fused_moe( | |
quant_scales=quant_scales, | ||
input_sf=input_sf, | ||
output=flash_output, | ||
) | ||
), | ||
kernel_names="cutlass13device_kernelINS_4gemm6kernel", | ||
num_kernels_per_period=2, | ||
trace_path=f"{trace_dir}/{time.time()}.trace.json.gz" if trace_dir else None, | ||
) | ||
print( | ||
f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}" | ||
) | ||
print(f"execution time: {ms}ms") | ||
|
||
# NOTE MODIFIED | ||
print(f"MAIN_OUTPUT=" + json.dumps(dict( | ||
batch_size=batch_size, | ||
num_experts=num_experts, | ||
top_k=top_k, | ||
intermediate_size=intermediate_size, | ||
time_gemm1_us=time_gemm1 * 1e6, | ||
time_gemm2_us=time_gemm2 * 1e6, | ||
))) | ||
Comment on lines
+247
to
+254
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
# print( | ||
# f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}" | ||
# ) | ||
# print(f"execution time: {ms}ms") | ||
|
||
|
||
if __name__ == "__main__": | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -13,8 +13,13 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import Tuple | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import json | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import sys | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import tempfile | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import Tuple, Union, Optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import torch.distributed as dist | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from einops import rearrange, reduce, repeat | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -206,3 +211,118 @@ def dequantize_fp8(x, x_scale, scale_major_mode): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
out = rearrange(x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 t2)") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class empty_suppress: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __enter__(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __exit__(self, *_): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pass | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class suppress_stdout_stderr: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __enter__(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.outnull_file = open(os.devnull, 'w') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.errnull_file = open(os.devnull, 'w') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.old_stdout_fileno_undup = sys.stdout.fileno() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.old_stderr_fileno_undup = sys.stderr.fileno() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.old_stdout_fileno = os.dup(sys.stdout.fileno()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.old_stderr_fileno = os.dup(sys.stderr.fileno()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.old_stdout = sys.stdout | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.old_stderr = sys.stderr | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sys.stdout = self.outnull_file | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sys.stderr = self.errnull_file | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __exit__(self, *_): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sys.stdout = self.old_stdout | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sys.stderr = self.old_stderr | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.close(self.old_stdout_fileno) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.close(self.old_stderr_fileno) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.outnull_file.close() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.errnull_file.close() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+224
to
+256
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current implementation of
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# copy and modified from DeepGEMM and DeepEP | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppress_kineto_output: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
trace_path: Optional[str] = None, barrier_comm_profiling: bool = False, flush_l2: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
num_kernels_per_period: int = 1): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Conflict with Nsight Systems | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
flush_l2_size = int(8e9 // 4) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# For some auto-tuning kernels with prints | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
fn() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Profile | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
with suppress(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
with profiler: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for i in range(2): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if barrier_comm_profiling: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
lhs @ rhs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for _ in range(num_tests): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if flush_l2: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
fn() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if not using_nsys: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
profiler.step() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Return 1 if using Nsight Systems | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if using_nsys: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Parse the profiling table | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
is_tuple = isinstance(kernel_names, tuple) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# print(f"prof_lines=\n" + "\n".join(prof_lines)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Save chrome traces | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if trace_path is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"export_chrome_trace to {trace_path=}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
profiler.export_chrome_trace(trace_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert all([isinstance(name, str) for name in kernel_names]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
kernel_durations = [None] * len(kernel_names) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Expand the kernels by periods | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
with tempfile.NamedTemporaryFile(suffix='.json') as tmp: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
profiler.export_chrome_trace(tmp.name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
profile_data = json.loads(Path(tmp.name).read_text()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for i, kernel_name in enumerate(kernel_names): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
events = [event for event in profile_data['traceEvents'] if kernel_name in event['name']] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
events = sorted(events, key=lambda event: event['ts']) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
durations = [event['dur'] / 1e6 for event in events] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert len(durations) % num_kernels_per_period == 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
num_kernel_patterns = len(durations) // num_kernels_per_period | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for j in range(num_kernels_per_period)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Return execution durations | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return kernel_durations if is_tuple else kernel_durations[0] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,7 +62,8 @@ def generate_build_meta(aot_build_meta: dict) -> None: | |
"einops", | ||
"nvidia-nvshmem-cu12", | ||
"nvidia-cudnn-cu12", | ||
"nvidia-cudnn-frontend", | ||
# NOTE MODIFIED rm | ||
# "nvidia-cudnn-frontend", | ||
Comment on lines
+65
to
+66
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
] | ||
generate_build_meta({}) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider removing the unused
sys
andtime
imports to maintain clean and readable code.