Skip to content

Commit 131ed8e

Browse files
committed
add test for group_offloading with training.
1 parent fb29132 commit 131ed8e

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/models/test_modeling_common.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,30 @@ def run_forward(model):
15811581
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
15821582
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
15831583

1584+
@parameterized.expand([False, True])
1585+
@require_torch_accelerator
1586+
def test_group_offloading_with_training(self, use_stream):
1587+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1588+
model = self.model_class(**init_dict)
1589+
if not getattr(model, "_supports_group_offloading", True):
1590+
return
1591+
1592+
model.enable_group_offload(
1593+
torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=use_stream
1594+
)
1595+
model.train()
1596+
1597+
output = model(**inputs_dict)
1598+
1599+
if isinstance(output, dict):
1600+
output = output.to_tuple()[0]
1601+
1602+
input_tensor = inputs_dict[self.main_input_name]
1603+
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
1604+
loss = torch.nn.functional.mse_loss(output, noise)
1605+
1606+
loss.backward()
1607+
15841608
def test_auto_model(self, expected_max_diff=5e-5):
15851609
if self.forward_requires_fresh_args:
15861610
model = self.model_class(**self.init_dict)

0 commit comments

Comments
 (0)