55# LICENSE file in the root directory of this source tree.
66# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
77
8+ import argparse
89import itertools
910from dataclasses import dataclass
1011from typing import List
1516from tqdm import tqdm
1617from triton .testing import do_bench
1718
18- from benchmarks .utils import bench_fwd_bwd_microseconds
19+ from benchmarks .utils import bench_fwd_bwd_microseconds , profile_fwd_bwd
1920from torchao .prototype .blockwise_fp8_training .linear import Float8BlockwiseLinear
2021
2122device = torch .device ("cuda" )
@@ -71,7 +72,7 @@ def get_configs() -> List[ExperimentConfig]:
7172 return configs
7273
7374
74- def run_experiment (config : ExperimentConfig ) -> ExperimentResult :
75+ def run_experiment (config : ExperimentConfig , profile = False , use_compile = False ) -> ExperimentResult :
7576 M , N , K = config .m , config .n , config .k
7677 inputs = torch .randn (M , K , dtype = config .out_dtype , device = "cuda" )
7778 bf16_linear = torch .nn .Linear (K , N , dtype = config .out_dtype , device = "cuda" )
@@ -83,49 +84,59 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
8384 )
8485
8586 def warmup (func , * args , ** kwargs ):
86- for _ in range (10 ):
87+ for _ in range (3 ):
8788 func (* args , ** kwargs )
8889
89- def fwd_bwd (func , inputs , labels , * args , ** kwargs ):
90- out = func (inputs , * args , ** kwargs )
91- loss = F .mse_loss (out , labels )
92- loss .backward ()
93- torch .cuda .synchronize ()
9490
95- # Warmup then run bf16 torch.mm
91+ # bfloat16 bench and profile
9692 labels = inputs .new_empty (M , N ).fill_ (1.0 )
97- warmup (fwd_bwd , bf16_linear , inputs , labels )
98-
99- bf16_linear_us = benchmark_cuda_function_in_microseconds (
100- fwd_bwd , bf16_linear , inputs , labels
93+ bf16_linear_us = bench_fwd_bwd_microseconds (
94+ bf16_linear ,
95+ inputs ,
96+ labels = labels ,
97+ use_compile = use_compile ,
10198 )
102-
103- # Warm up then run triton bench
104- warmup (
105- fwd_bwd ,
106- fp8_triton_linear ,
107- inputs ,
108- labels ,
99+ if profile :
100+ print ("Profiling bf16_linear" )
101+ profile_fwd_bwd (
102+ bf16_linear ,
103+ inputs ,
104+ labels = labels ,
105+ profile_name = "bf16_linear_profile" ,
106+ use_compile = use_compile ,
109107 )
110108
109+ # FP8 triton bench and profile
111110 fp8_triton_linear_us = bench_fwd_bwd_microseconds (
112111 fp8_triton_linear ,
113112 inputs ,
114113 labels = labels ,
115114 )
115+ if profile :
116+ print ("Profiling fp8_triton_linear" )
117+ profile_fwd_bwd (
118+ fp8_triton_linear ,
119+ inputs ,
120+ labels = labels ,
121+ profile_name = "fp8_triton_linear_profile" ,
122+ )
116123
117- warmup (
118- fwd_bwd ,
119- fp8_scaled_mm_linear ,
120- inputs ,
121- labels ,
122- )
123-
124+ # FP8 torch._scaled_mm bench and profile
124125 fp8_scaled_mm_linear_us = bench_fwd_bwd_microseconds (
125126 fp8_scaled_mm_linear ,
126127 inputs ,
127128 labels = labels ,
129+ use_compile = use_compile ,
128130 )
131+ if profile :
132+ print ("Profiling fp8_scaled_mm_linear" )
133+ profile_fwd_bwd (
134+ fp8_scaled_mm_linear ,
135+ inputs ,
136+ labels = labels ,
137+ profile_name = "fp8_scaled_mm_linear_profile" ,
138+ use_compile = use_compile ,
139+ )
129140
130141 return ExperimentResult (
131142 bf16_linear_us = bf16_linear_us ,
@@ -165,17 +176,21 @@ def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
165176 return do_bench (lambda : f (* args , ** kwargs ), return_mode = "median" ) * 1e3
166177
167178
168- def main ():
179+ def main (args : argparse . Namespace ):
169180 torch .random .manual_seed (123 )
170181 configs = get_configs ()
171182 results = []
172183 for config in tqdm (configs ):
173- result = run_experiment (config )
184+ result = run_experiment (config , profile = args . profile , use_compile = args . compile )
174185 results .append (Experiment (config = config , result = result ))
175186
176187 # Use Tabulate to print results
177188 print_results (results )
178189
179190
180191if __name__ == "__main__" :
181- main ()
192+ parser = argparse .ArgumentParser ()
193+ parser .add_argument ("--profile" , action = "store_true" , help = "Enable profiling" )
194+ parser .add_argument ("--compile" , action = "store_true" , help = "Enable compilation" )
195+ args = parser .parse_args ()
196+ main (args )
0 commit comments