@@ -11,7 +11,7 @@ def benchmark_fn(f, *args, **kwargs):
1111    t0  =  benchmark .Timer (
1212        stmt = "f(*args, **kwargs)" ,
1313        globals = {"args" : args , "kwargs" : kwargs , "f" : f },
14-         num_threads = torch . get_num_threads () ,
14+         num_threads = 1 ,
1515    )
1616    return  f"{ (t0 .blocked_autorange ().mean ):.3f}  " 
1717
@@ -53,10 +53,6 @@ def run_benchmark(self):
5353        model  =  self .initialize_model ()  # Takes care of device placement. 
5454        input_dict  =  self .get_input_dict ()  # Takes care of device placement. 
5555
56-         # warmup 
57-         for  _  in  range (5 ):
58-             _  =  model (** input_dict )
59- 
6056        time  =  benchmark_fn (lambda  model , input_dict : model (** input_dict ), model , input_dict )
6157        memory  =  torch .cuda .max_memory_allocated () /  (1024 ** 3 )
6258        memory  =  float (f"{ memory :.2f}  " )
@@ -69,9 +65,9 @@ def run_benchmark(self):
6965        compile_stats  =  None 
7066        if  self .compile_kwargs  is  not   None :
7167            model  =  self .initialize_model ()
72-             with   torch . _inductor . utils . fresh_inductor_cache (): 
73-                  model .compile (** self .compile_kwargs )
74-                  time  =  benchmark_fn (lambda  model , input_dict : model (** input_dict ), model , input_dict )
68+             input_dict   =   self . get_input_dict () 
69+             model .compile (** self .compile_kwargs )
70+             time  =  benchmark_fn (lambda  model , input_dict : model (** input_dict ), model , input_dict )
7571            memory  =  torch .cuda .max_memory_allocated () /  (1024 ** 3 )
7672            memory  =  float (f"{ memory :.2f}  " )
7773            compile_stats  =  {"time" : time , "memory" : memory }
0 commit comments