@@ -1744,6 +1744,10 @@ def test_push_to_hub_library_name(self):
         delete_repo(self.repo_id, token=TOKEN)


+@require_torch_gpu
+@require_torch_2
+@is_torch_compile
+@slow
 class TorchCompileTesterMixin:
     def setUp(self):
         # clean up the VRAM before each test
@@ -1759,12 +1763,7 @@ def tearDown(self):
         gc.collect()
         backend_empty_cache(torch_device)

-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
     def test_torch_compile_recompilation_and_graph_break(self):
-        torch.compiler.reset()
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

         model = self.model_class(**init_dict).to(torch_device)
@@ -1778,6 +1777,31 @@ def test_torch_compile_recompilation_and_graph_break(self):
             _ = model(**inputs_dict)
             _ = model(**inputs_dict)

+    def test_compile_with_group_offloading(self):
+        torch._dynamo.config.cache_size_limit = 10000
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        if not getattr(model, "_supports_group_offloading", True):
+            return
+
+        model.eval()
+        # TODO: Can test for other group offloading kwargs later if needed.
+        group_offload_kwargs = {
+            "onload_device": "cuda",
+            "offload_device": "cpu",
+            "offload_type": "block_level",
+            "num_blocks_per_group": 1,
+            "use_stream": True,
+            "non_blocking": True,
+        }
+        model.enable_group_offload(**group_offload_kwargs)
+        model.compile()
+        with torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)
+

 @slow
 @require_torch_2
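
For context, the new test combines group offloading with `torch.compile` on the same module. Below is a minimal standalone sketch of that pattern; the tiny `UNet2DModel` config and the random inputs are illustrative assumptions standing in for `self.model_class(**init_dict)` and `inputs_dict`, not code from this PR.

```python
# Minimal sketch of the pattern exercised by test_compile_with_group_offloading.
# Assumes a CUDA device; the model config and inputs below are illustrative.
import torch
from diffusers import UNet2DModel

# Offloading hooks can trigger extra recompiles, hence the generous budget
# (mirrors the cache_size_limit bump in the test).
torch._dynamo.config.cache_size_limit = 10000

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
model.eval()

# Weights stay on the CPU and are streamed onto the GPU one block at a time.
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    non_blocking=True,
)
model.compile()  # nn.Module.compile(): compiles forward in place via torch.compile

with torch.no_grad():
    sample = torch.randn(1, 3, 32, 32, device="cuda")
    timestep = torch.tensor([1], device="cuda")
    _ = model(sample, timestep)  # first call: compilation plus weight streaming
    _ = model(sample, timestep)  # second call: should reuse the compiled graph
```

Running the forward pass twice matters: the first call pays the compilation cost, and the second is where a recompile forced by the offloading hooks would show up.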