@@ -1580,6 +1580,34 @@ def run_forward(model):
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
 
+    @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
+    @require_torch_accelerator
+    @torch.no_grad()
+    def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        if not getattr(model, "_supports_group_offloading", True):
+            return
+
+        model.to(torch_device)
+        model.eval()
+        _ = model(**inputs_dict)[0]
+
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        storage_dtype, compute_dtype = torch.float16, torch.float32
+        inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+        model = self.model_class(**init_dict)
+        model.eval()
+        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
+        model.enable_group_offload(
+            torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
+        )
+        model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+        _ = model(**inputs_dict)[0]
+
     def test_auto_model(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
             model = self.model_class(**self.init_dict)
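
For reference, here is a minimal usage sketch of the combination this test exercises: group offloading plus layerwise casting on a diffusers model. The model class and checkpoint (AutoencoderKL, stabilityai/sd-vae-ft-mse) are illustrative assumptions, not part of the test.

# Minimal usage sketch, assuming a CUDA device and an AutoencoderKL checkpoint.
import torch
from diffusers import AutoencoderKL

model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
model.eval()

# Keep parameter groups offloaded to CPU and stream each block onto the GPU on demand.
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
)
# Store weights in fp16, upcasting to fp32 for each module's forward pass.
model.enable_layerwise_casting(storage_dtype=torch.float16, compute_dtype=torch.float32)

with torch.no_grad():
    sample = torch.randn(1, 3, 64, 64, device="cuda", dtype=torch.float32)
    _ = model(sample).sample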