Skip to content

Commit 535dcd1

Browse files
committed
tests
1 parent 1d4ca61 commit 535dcd1

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tests/models/test_modeling_common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,8 +1525,9 @@ def get_memory_usage(storage_dtype, compute_dtype):
15251525
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
15261526
)
15271527

1528+
@parameterized.expand([False, True])
15281529
@require_torch_gpu
1529-
def test_group_offloading(self):
1530+
def test_group_offloading(self, record_stream):
15301531
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
15311532
torch.manual_seed(0)
15321533

@@ -1566,7 +1567,9 @@ def run_forward(model):
15661567

15671568
torch.manual_seed(0)
15681569
model = self.model_class(**init_dict)
1569-
model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True)
1570+
model.enable_group_offload(
1571+
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
1572+
)
15701573
output_with_group_offloading4 = run_forward(model)
15711574

15721575
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))

0 commit comments

Comments
 (0)