@@ -146,6 +146,25 @@ def get_timing_stats_cpu(elapsed_times: list[float]):
146146 return stats
147147
148148
def measure_performance(model_call, args, compiler):
    """Time ``model_call`` and return the summary statistics for the run.

    The timing strategy is selected from ``args.device``:

    * ``"cuda"`` -- the call is timed with ``time_execution_with_cuda_event``
      on the fixed device ``cuda:0`` and summarized by ``get_timing_stats``.
    * anything else -- the call is timed with ``time_execution_naive``, which
      is handed ``compiler.synchronize``, and summarized by
      ``get_timing_stats_cpu``.

    Args:
        model_call: Zero-argument callable to benchmark.
        args: Options object; ``device``, ``warmup`` and ``trials`` are read.
        compiler: Backend object; only its ``synchronize`` attribute is used,
            and only on the non-CUDA path.

    Returns:
        Whatever the selected stats helper returns for the measured times.
    """
    # Handle the non-CUDA case first so the CUDA path reads straight through.
    if args.device != "cuda":
        elapsed = time_execution_naive(
            model_call,
            compiler.synchronize,
            num_warmup=args.warmup,
            num_trials=args.trials,
        )
        return get_timing_stats_cpu(elapsed)

    elapsed = time_execution_with_cuda_event(
        model_call,
        num_warmup=args.warmup,
        num_trials=args.trials,
        # NOTE(review): device index is hard-coded to 0 — confirm that
        # multi-GPU selection is not expected here.
        device=torch.device("cuda:0"),
    )
    return get_timing_stats(elapsed)
166+
167+
149168def test_single_model (args ):
150169 compiler = get_compiler_backend (args )
151170 input_dict = get_input_dict (args )
@@ -192,38 +211,8 @@ def test_single_model(args):
192211 eager_model_call = lambda : model (** input_dict )
193212 compiled_model_call = lambda : compiled_model (** input_dict )
194213
195- if args .device == "cuda" :
196- eager_times = time_execution_with_cuda_event (
197- eager_model_call ,
198- num_warmup = args .warmup ,
199- num_trials = args .trials ,
200- device = torch .device ("cuda:0" ),
201- )
202- eager_stats = get_timing_stats (eager_times )
203-
204- compiled_times = time_execution_with_cuda_event (
205- compiled_model_call ,
206- num_warmup = args .warmup ,
207- num_trials = args .trials ,
208- device = torch .device ("cuda:0" ),
209- )
210- compiled_stats = get_timing_stats (compiled_times )
211- else :
212- eager_times = time_execution_naive (
213- eager_model_call ,
214- compiler .synchronize ,
215- num_warmup = args .warmup ,
216- num_trials = args .trials ,
217- )
218- eager_stats = get_timing_stats_cpu (eager_times )
219-
220- compiled_times = time_execution_naive (
221- compiled_model_call ,
222- compiler .synchronize ,
223- num_warmup = args .warmup ,
224- num_trials = args .trials ,
225- )
226- compiled_stats = get_timing_stats_cpu (compiled_times )
214+ eager_stats = measure_performance (eager_model_call , args , compiler )
215+ compiled_stats = measure_performance (compiled_model_call , args , compiler )
227216
228217 expected_out = eager_model_call ()
229218 compiled_out = compiled_model_call ()
0 commit comments