3737_GROUP_OFFLOADING = "group_offloading"
3838_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
3939_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
40-
40+ _GROUP_ID_LAZY_LEAF = "lazy_leafs"
41+ _GROUP_ID_UNMATCHED_GROUP = "top_level_unmatched_modules"
4142_SUPPORTED_PYTORCH_LAYERS = (
4243 torch .nn .Conv1d , torch .nn .Conv2d , torch .nn .Conv3d ,
4344 torch .nn .ConvTranspose1d , torch .nn .ConvTranspose2d , torch .nn .ConvTranspose3d ,
@@ -84,7 +85,7 @@ def __init__(
8485
8586 if self .offload_to_disk_path :
8687 self ._group_id = _group_id
87- short_hash = self . _compute_group_hash (self ._group_id )
88+ short_hash = _compute_group_hash (self ._group_id )
8889 self .safetensors_file_path = os .path .join (self .offload_to_disk_path , f"group_{ short_hash } .safetensors" )
8990
9091 all_tensors = []
@@ -265,11 +266,6 @@ def offload_(self):
265266 for buffer in self .buffers :
266267 buffer .data = buffer .data .to (self .offload_device , non_blocking = self .non_blocking )
267268
268- def _compute_group_hash (self , group_id ):
269- hashed_id = hashlib .sha256 (group_id .encode ("utf-8" )).hexdigest ()
270- # first 16 characters for a reasonably short but unique name
271- return hashed_id [:16 ]
272-
273269
274270class GroupOffloadingHook (ModelHook ):
275271 r"""
@@ -663,7 +659,7 @@ def _apply_group_offloading_block_level(
663659 stream = None ,
664660 record_stream = False ,
665661 onload_self = True ,
666- _group_id = "top_level_unmatched_modules" ,
662+ _group_id = _GROUP_ID_UNMATCHED_GROUP ,
667663 )
668664 if stream is None :
669665 _apply_group_offloading_hook (module , unmatched_group , None )
@@ -800,7 +796,7 @@ def _apply_group_offloading_leaf_level(
800796 record_stream = False ,
801797 low_cpu_mem_usage = low_cpu_mem_usage ,
802798 onload_self = True ,
803- name = "lazy_leafs" ,
799+ _group_id = _GROUP_ID_LAZY_LEAF ,
804800 )
805801 _apply_lazy_group_offloading_hook (module , unmatched_group , None )
806802
@@ -910,6 +906,12 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
910906 raise ValueError ("Group offloading is not enabled for the provided module." )
911907
912908
909+ def _compute_group_hash (group_id ):
910+ hashed_id = hashlib .sha256 (group_id .encode ("utf-8" )).hexdigest ()
911+ # first 16 characters for a reasonably short but unique name
912+ return hashed_id [:16 ]
913+
914+
913915def _get_expected_safetensors_files (
914916 module : torch .nn .Module ,
915917 offload_to_disk_path : str ,
@@ -919,8 +921,7 @@ def _get_expected_safetensors_files(
919921 expected_files = set ()
920922
921923 def get_hashed_filename (group_id : str ) -> str :
922- hashed_id = hashlib .sha256 (group_id .encode ("utf-8" )).hexdigest ()
923- short_hash = hashed_id [:16 ]
924+ short_hash = _compute_group_hash (group_id )
924925 return os .path .join (offload_to_disk_path , f"group_{ short_hash } .safetensors" )
925926
926927 if offload_type == "block_level" :
@@ -942,7 +943,7 @@ def get_hashed_filename(group_id: str) -> str:
942943 expected_files .add (get_hashed_filename (group_id ))
943944
944945 # Handle the group for unmatched top-level modules and parameters
945- group_id = "top_level_unmatched_modules"
946+ group_id = _GROUP_ID_UNMATCHED_GROUP
946947 expected_files .add (get_hashed_filename (group_id ))
947948
948949 elif offload_type == "leaf_level" :
@@ -972,6 +973,7 @@ def get_hashed_filename(group_id: str) -> str:
972973 for parent_name in parent_to_tensors :
973974 # A file is expected for each parent that gathers orphaned tensors
974975 expected_files .add (get_hashed_filename (parent_name ))
976+ expected_files .add (get_hashed_filename (_GROUP_ID_LAZY_LEAF ))
975977
976978 else :
977979 raise ValueError (f"Unsupported offload_type: { offload_type } " )