diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 6c8113e8cb..3365fba923 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -60,7 +60,7 @@ @torch.no_grad() -def get_gpu_kernel_time(m, x): +def get_gpu_kernel_time(m, x, trace_filename=None): # warm up for _ in range(2): __ = m(x) @@ -72,6 +72,12 @@ def get_gpu_kernel_time(m, x): for _ in range(n_iter): __ = m(x) torch.cuda.synchronize() + + # save a trace, if requested + if trace_filename is not None: + print(f"exporting trace to {trace_filename}") + prof.export_chrome_trace(trace_filename) + # get the gpu kernel time and aggregate it num_leaf_tensors = 1 + len(list(m.parameters())) ref_times = profiler_output_to_filtered_time_by_kernel_name( @@ -161,6 +167,7 @@ def run( do_benchmarks: bool = True, shape_gen_name: str = "pow2", n_limit: Optional[int] = None, + save_profile_traces: bool = False, ): """ Args: @@ -168,6 +175,7 @@ def run( * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep` * `n_limit (optional)`: if specified, only runs `n_limit` iterations + * `save_profile_traces (optional)`: if True, saves profiling traces """ config_table = [ ["GPU", torch.cuda.get_device_name(0)], @@ -289,7 +297,11 @@ def run( # get the bf16 gpu kernel time torch._dynamo.reset() m_bf16 = torch.compile(copy.deepcopy(m_orig)) - b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x) + + bf16_trace_filename = None + if save_profile_traces: + bf16_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_bf16.json" + b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, bf16_trace_filename) # get the float8 dynamic scaling gpu kernel time torch._dynamo.reset() @@ -325,7 +337,11 @@ def run( quantize_(m_fp8_dyn, config) m_fp8_dyn = torch.compile(m_fp8_dyn) - b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x) + + 
fp8_trace_filename = None + if save_profile_traces: + fp8_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_fp8.json" + b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, fp8_trace_filename) r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)