@@ -1581,6 +1581,30 @@ def run_forward(model):
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
 
+    @parameterized.expand([False, True])
+    @require_torch_accelerator
+    def test_group_offloading_with_training(self, use_stream):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        if not getattr(model, "_supports_group_offloading", True):
+            return
+
+        model.enable_group_offload(
+            torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=use_stream
+        )
+        model.train()
+
+        output = model(**inputs_dict)
+
+        if isinstance(output, dict):
+            output = output.to_tuple()[0]
+
+        input_tensor = inputs_dict[self.main_input_name]
+        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
+        loss = torch.nn.functional.mse_loss(output, noise)
+
+        loss.backward()
+
     def test_auto_model(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
             model = self.model_class(**self.init_dict)
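
A minimal usage sketch of the pattern the new test exercises: enable block-level group offloading on a model, switch it to training mode, run a forward pass, and backpropagate. The model class, config, and shapes below are illustrative assumptions, not taken from the test suite; the `enable_group_offload` call mirrors the arguments used in the test.

```python
import torch
from diffusers import UNet2DModel

device = torch.device("cuda")  # streamed group offloading needs an accelerator

# Small illustrative config; any diffusers model that supports group offloading works.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

# Keep parameter groups offloaded and onload them block-by-block during the pass.
model.enable_group_offload(device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
model.train()

sample = torch.randn(1, 3, 32, 32, device=device)
output = model(sample, timestep=10).sample

# Gradients should flow through the offloaded blocks, which is what the test
# verifies by calling loss.backward() after the training-mode forward pass.
loss = torch.nn.functional.mse_loss(output, torch.randn_like(output))
loss.backward()
```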