Skip to content

Commit a9b7abe

Browse files
committed
updates
1 parent 6639f25 commit a9b7abe

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

tests/models/test_modeling_common.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1728,17 +1728,21 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17281728
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
17291729
# Group offloading with disk support related checks.
17301730
self.assertTrue(has_safetensors, "No safetensors found in the directory.")
1731-
is_correct, extra_files, missing_files = _check_safetensors_serialization(
1732-
module=model,
1733-
offload_to_disk_path=tmpdir,
1734-
offload_type=offload_type,
1735-
num_blocks_per_group=num_blocks_per_group,
1736-
)
1737-
if not is_correct:
1738-
if extra_files:
1739-
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
1740-
elif missing_files:
1741-
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
1731+
1732+
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
1733+
# in nature. So, skip it.
1734+
if offload_type != "leaf_level":
1735+
is_correct, extra_files, missing_files = _check_safetensors_serialization(
1736+
module=model,
1737+
offload_to_disk_path=tmpdir,
1738+
offload_type=offload_type,
1739+
num_blocks_per_group=num_blocks_per_group,
1740+
)
1741+
if not is_correct:
1742+
if extra_files:
1743+
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
1744+
elif missing_files:
1745+
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
17421746

17431747
output_with_group_offloading = model(**inputs_dict)[0]
17441748
self.assertTrue(

0 commit comments

Comments
 (0)