Skip to content

Commit aacf625

Browse files
committed
feedback
1 parent ecdb8e3 commit aacf625

File tree

1 file changed

+6
-17
lines changed

1 file changed

+6
-17
lines changed

tests/models/test_modeling_common.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,10 +1580,10 @@ def run_forward(model):
15801580
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
15811581
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
15821582

1583-
@parameterized.expand([(False, torch.float16, torch.float32), (True, torch.float16, torch.float32)])
1583+
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
15841584
@require_torch_accelerator
15851585
@torch.no_grad()
1586-
def test_group_offloading_with_layerwise_casting(self, record_stream, storage_dtype, compute_dtype):
1586+
def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
15871587
torch.manual_seed(0)
15881588
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
15891589
model = self.model_class(**init_dict)
@@ -1597,26 +1597,15 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, storage_dt
15971597

15981598
torch.manual_seed(0)
15991599
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1600+
storage_dtype, compute_dtype = torch.float16, torch.float32
16001601
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
16011602
model = self.model_class(**init_dict)
16021603
model.eval()
1603-
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
1604-
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
1605-
_ = model(**inputs_dict)[0]
1606-
1607-
torch.manual_seed(0)
1608-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1609-
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
1610-
model = self.model_class(**init_dict)
1611-
model.eval()
1604+
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
16121605
model.enable_group_offload(
1613-
torch_device,
1614-
offload_type="block_level",
1615-
num_blocks_per_group=1,
1616-
use_stream=True,
1617-
non_blocking=True,
1618-
record_stream=record_stream,
1606+
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
16191607
)
1608+
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
16201609
_ = model(**inputs_dict)[0]
16211610

16221611
def test_auto_model(self, expected_max_diff=5e-5):

0 commit comments

Comments (0)