@@ -1728,17 +1728,21 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17281728 has_safetensors = glob .glob (f"{ tmpdir } /*.safetensors" )
17291729 # Group offloading with disk support related checks.
17301730 self .assertTrue (has_safetensors , "No safetensors found in the directory." )
1731- is_correct , extra_files , missing_files = _check_safetensors_serialization (
1732- module = model ,
1733- offload_to_disk_path = tmpdir ,
1734- offload_type = offload_type ,
1735- num_blocks_per_group = num_blocks_per_group ,
1736- )
1737- if not is_correct :
1738- if extra_files :
1739- raise ValueError (f"Found extra files: { ', ' .join (extra_files )} " )
1740- elif missing_files :
1741- raise ValueError (f"Following files are missing: { ', ' .join (missing_files )} " )
1731+
1732+ # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
1733+ # in nature. So, skip it.
1734+ if offload_type != "leaf_level" :
1735+ is_correct , extra_files , missing_files = _check_safetensors_serialization (
1736+ module = model ,
1737+ offload_to_disk_path = tmpdir ,
1738+ offload_type = offload_type ,
1739+ num_blocks_per_group = num_blocks_per_group ,
1740+ )
1741+ if not is_correct :
1742+ if extra_files :
1743+ raise ValueError (f"Found extra files: { ', ' .join (extra_files )} " )
1744+ elif missing_files :
1745+ raise ValueError (f"Following files are missing: { ', ' .join (missing_files )} " )
17421746
17431747 output_with_group_offloading = model (** inputs_dict )[0 ]
17441748 self .assertTrue (
0 commit comments