@@ -1580,10 +1580,10 @@ def run_forward(model):
15801580 self .assertTrue (torch .allclose (output_without_group_offloading , output_with_group_offloading3 , atol = 1e-5 ))
15811581 self .assertTrue (torch .allclose (output_without_group_offloading , output_with_group_offloading4 , atol = 1e-5 ))
15821582
1583- @parameterized .expand ([(False , torch . float16 , torch . float32 ), (True , torch . float16 , torch . float32 )])
1583+ @parameterized .expand ([(False , "block_level" ), (True , "leaf_level" )])
15841584 @require_torch_accelerator
15851585 @torch .no_grad ()
1586- def test_group_offloading_with_layerwise_casting (self , record_stream , storage_dtype , compute_dtype ):
1586+ def test_group_offloading_with_layerwise_casting (self , record_stream , offload_type ):
15871587 torch .manual_seed (0 )
15881588 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
15891589 model = self .model_class (** init_dict )
@@ -1597,26 +1597,15 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, storage_dt
15971597
15981598 torch .manual_seed (0 )
15991599 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
1600+ storage_dtype , compute_dtype = torch .float16 , torch .float32
16001601 inputs_dict = cast_maybe_tensor_dtype (inputs_dict , torch .float32 , compute_dtype )
16011602 model = self .model_class (** init_dict )
16021603 model .eval ()
1603- model .enable_group_offload (torch_device , offload_type = "block_level" , num_blocks_per_group = 1 )
1604- model .enable_layerwise_casting (storage_dtype = storage_dtype , compute_dtype = compute_dtype )
1605- _ = model (** inputs_dict )[0 ]
1606-
1607- torch .manual_seed (0 )
1608- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
1609- inputs_dict = cast_maybe_tensor_dtype (inputs_dict , torch .float32 , compute_dtype )
1610- model = self .model_class (** init_dict )
1611- model .eval ()
1604+ additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group" : 1 }
16121605 model .enable_group_offload (
1613- torch_device ,
1614- offload_type = "block_level" ,
1615- num_blocks_per_group = 1 ,
1616- use_stream = True ,
1617- non_blocking = True ,
1618- record_stream = record_stream ,
1606+ torch_device , offload_type = offload_type , use_stream = True , record_stream = record_stream , ** additional_kwargs
16191607 )
1608+ model .enable_layerwise_casting (storage_dtype = storage_dtype , compute_dtype = compute_dtype )
16201609 _ = model (** inputs_dict )[0 ]
16211610
16221611 def test_auto_model (self , expected_max_diff = 5e-5 ):
0 commit comments