@@ -1744,6 +1744,10 @@ def test_push_to_hub_library_name(self):
         delete_repo(self.repo_id, token=TOKEN)


+@require_torch_gpu
+@require_torch_2
+@is_torch_compile
+@slow
 class TorchCompileTesterMixin:
     def setUp(self):
         # clean up the VRAM before each test
@@ -1759,12 +1763,7 @@ def tearDown(self):
         gc.collect()
         backend_empty_cache(torch_device)

-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
     def test_torch_compile_recompilation_and_graph_break(self):
-        torch.compiler.reset()
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

         model = self.model_class(**init_dict).to(torch_device)
@@ -1778,6 +1777,31 @@ def test_torch_compile_recompilation_and_graph_break(self):
         _ = model(**inputs_dict)
         _ = model(**inputs_dict)

+    def test_compile_with_group_offloading(self):
+        torch._dynamo.config.cache_size_limit = 10000
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        if not getattr(model, "_supports_group_offloading", True):
+            return
+
+        model.eval()
+        # TODO: Can test for other group offloading kwargs later if needed.
+        group_offload_kwargs = {
+            "onload_device": "cuda",
+            "offload_device": "cpu",
+            "offload_type": "block_level",
+            "num_blocks_per_group": 1,
+            "use_stream": True,
+            "non_blocking": True,
+        }
+        model.enable_group_offload(**group_offload_kwargs)
+        model.compile()
+        with torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)
+

 @slow
 @require_torch_2
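
The new `test_compile_with_group_offloading` combines two features: group offloading (parameters rest on the `offload_device` and blocks are moved to the `onload_device` on demand) and `torch.compile` via the in-place `nn.Module.compile()` helper; the `cache_size_limit` bump presumably keeps the per-block offloading hooks from hitting dynamo's recompile cap. A minimal standalone sketch of the same pattern outside the test harness, assuming a diffusers model such as `AutoencoderKL` (the checkpoint id and input shape are illustrative, not from this PR):

```python
import torch
from diffusers import AutoencoderKL

model = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16
)
model.eval()

# Keep weights on CPU and stream each block onto the GPU just in time.
# Note: do not call model.to("cuda") yourself; group offloading manages devices.
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    non_blocking=True,
)

# Compile in place; the second forward pass should reuse the compiled graph
# rather than recompiling, despite the offloading hooks.
model.compile()

sample = torch.randn(1, 3, 64, 64, dtype=torch.float16, device="cuda")
with torch.no_grad():
    _ = model(sample)
    _ = model(sample)
```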