@@ -108,7 +108,6 @@ def warm_up(
108108def measure_detailed_inference_timing (
109109 model , sample , model_device , transfer_to_device_fn = torch .Tensor .to
110110):
111-
112111 try :
113112 with torch .autograd .profiler .profile (
114113 use_cuda = (model_device .type == "cuda" ), profile_memory = True
@@ -135,7 +134,6 @@ def measure_repeated_inference_timing(
135134 num_runs = 100 ,
136135 batch_size : int = None ,
137136):
138-
139137 t_c2d = []
140138 t_inf = []
141139 t_d2c = []
@@ -146,14 +144,29 @@ def measure_repeated_inference_timing(
146144 ):
147145 start_on_cpu = time ()
148146 device_sample = transfer_to_device_fn (sample , model_device )
149- start_on_device = time ()
147+
148+ if model_device .type == "cuda" :
149+ start_event = torch .cuda .Event (enable_timing = True )
150+ stop_event = torch .cuda .Event (enable_timing = True )
151+ start_event .record () # For GPU timing
152+ start_on_device = time () # For CPU timing
153+
150154 device_result = model (device_sample )
151- stop_on_device = time ()
155+
156+ if model_device .type == "cuda" :
157+ stop_event .record ()
158+ torch .cuda .synchronize ()
159+ elapsed_on_device = stop_event .elapsed_time (start_event )
160+ stop_on_device = time ()
161+ else :
162+ stop_on_device = time ()
163+ elapsed_on_device = stop_on_device - start_on_device
164+
152165 transfer_to_device_fn (device_result , "cpu" )
153166 stop_on_cpu = time ()
154167
155168 t_c2d .append (start_on_device - start_on_cpu )
156- t_inf .append (stop_on_device - start_on_device )
169+ t_inf .append (elapsed_on_device )
157170 t_d2c .append (stop_on_cpu - stop_on_device )
158171 t_tot .append (stop_on_cpu - start_on_cpu )
159172
@@ -328,7 +341,11 @@ def benchmark(
328341 batch_size = 1 ,
329342 )
330343
331- flops = measure_flops (model , sample1 , print_details )
344+ with torch .no_grad ():
345+ flops = measure_flops (
346+ model , transfer_to_device_fn (sample1 , model_device ), print_details
347+ )
348+
332349 if _is_valid (flops ):
333350 results ["flops" ] = flops
334351 print_fn (f"Model FLOPs: { flops } ({ format_num (flops )} )" )
0 commit comments