@@ -1829,6 +1829,10 @@ def test_push_to_hub_library_name(self):
18291829        delete_repo (self .repo_id , token = TOKEN )
18301830
18311831
1832+ @require_torch_gpu  
1833+ @require_torch_2  
1834+ @is_torch_compile  
1835+ @slow  
18321836class  TorchCompileTesterMixin :
18331837    def  setUp (self ):
18341838        # clean up the VRAM before each test 
@@ -1844,12 +1848,7 @@ def tearDown(self):
18441848        gc .collect ()
18451849        backend_empty_cache (torch_device )
18461850
1847-     @require_torch_gpu  
1848-     @require_torch_2  
1849-     @is_torch_compile  
1850-     @slow  
18511851    def  test_torch_compile_recompilation_and_graph_break (self ):
1852-         torch .compiler .reset ()
18531852        init_dict , inputs_dict  =  self .prepare_init_args_and_inputs_for_common ()
18541853
18551854        model  =  self .model_class (** init_dict ).to (torch_device )
@@ -1863,6 +1862,31 @@ def test_torch_compile_recompilation_and_graph_break(self):
18631862            _  =  model (** inputs_dict )
18641863            _  =  model (** inputs_dict )
18651864
1865+     def  test_compile_with_group_offloading (self ):
1866+         torch ._dynamo .config .cache_size_limit  =  10000 
1867+ 
1868+         init_dict , inputs_dict  =  self .prepare_init_args_and_inputs_for_common ()
1869+         model  =  self .model_class (** init_dict )
1870+ 
1871+         if  not  getattr (model , "_supports_group_offloading" , True ):
1872+             return 
1873+ 
1874+         model .eval ()
1875+         # TODO: Can test for other group offloading kwargs later if needed. 
1876+         group_offload_kwargs  =  {
1877+             "onload_device" : "cuda" ,
1878+             "offload_device" : "cpu" ,
1879+             "offload_type" : "block_level" ,
1880+             "num_blocks_per_group" : 1 ,
1881+             "use_stream" : True ,
1882+             "non_blocking" : True ,
1883+         }
1884+         model .enable_group_offload (** group_offload_kwargs )
1885+         model .compile ()
1886+         with  torch .no_grad ():
1887+             _  =  model (** inputs_dict )
1888+             _  =  model (** inputs_dict )
1889+ 
18661890
18671891@slow  
18681892@require_torch_2  
0 commit comments