@parameterized.expand([False, True])
@require_torch_accelerator
def test_group_offloading_with_training(self, use_stream):
    """Run one forward/backward pass in train mode with group offloading on.

    The test is expanded over ``use_stream`` (``False`` and ``True``), which
    is forwarded unchanged to ``enable_group_offload``. It makes no numerical
    assertions: it succeeds when the forward pass, the MSE loss computation
    and ``loss.backward()`` all complete without raising.

    Models whose ``_supports_group_offloading`` attribute is falsy are
    reported as skipped instead of silently passing.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    model = self.model_class(**init_dict)

    # A missing attribute is treated as "supported" (default of True).
    if not getattr(model, "_supports_group_offloading", True):
        # Skip explicitly so the runner does not count an untested model as
        # a passing case.
        self.skipTest("Model does not support group offloading.")

    model.enable_group_offload(
        torch_device,
        offload_type="block_level",
        num_blocks_per_group=1,
        use_stream=use_stream,
    )
    model.train()

    output = model(**inputs_dict)

    # Dict-like model outputs are reduced to their first element.
    if isinstance(output, dict):
        output = output.to_tuple()[0]

    # The regression target takes the batch size from the main input and
    # the remaining dimensions from ``self.output_shape``.
    batch_size = inputs_dict[self.main_input_name].shape[0]
    noise = torch.randn((batch_size,) + self.output_shape).to(torch_device)
    loss = torch.nn.functional.mse_loss(output, noise)

    # The backward pass is the operation under test.
    loss.backward()