@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import glob
 import hashlib
 import os
 from contextlib import contextmanager, nullcontext
@@ -37,8 +36,7 @@
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
-_GROUP_ID_LAZY_LEAF = "lazy_leafs"
-_GROUP_ID_UNMATCHED_GROUP = "top_level_unmatched_modules"
+GROUP_ID_LAZY_LEAF = "lazy_leafs"
 _SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -609,7 +607,7 @@ def _apply_group_offloading_block_level(
 
         for i in range(0, len(submodule), num_blocks_per_group):
             current_modules = submodule[i : i + num_blocks_per_group]
-            group_id = f"{name}_{i}_{i + len(current_modules)- 1}"
+            group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
             group = ModuleGroup(
                 modules=current_modules,
                 offload_device=offload_device,
@@ -622,7 +620,7 @@ def _apply_group_offloading_block_level(
                 record_stream=record_stream,
                 low_cpu_mem_usage=low_cpu_mem_usage,
                 onload_self=True,
-                _group_id=group_id,
+                group_id=group_id,
             )
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
@@ -657,7 +655,7 @@ def _apply_group_offloading_block_level(
         stream=None,
         record_stream=False,
         onload_self=True,
-        _group_id=_GROUP_ID_UNMATCHED_GROUP,
+        group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if stream is None:
         _apply_group_offloading_hook(module, unmatched_group, None)
@@ -724,7 +722,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
-            _group_id=name,
+            group_id=name,
         )
         _apply_group_offloading_hook(submodule, group, None)
         modules_with_group_offloading.add(name)
@@ -772,7 +770,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
-            _group_id=name,
+            group_id=name,
         )
         _apply_group_offloading_hook(parent_module, group, None)
 
@@ -794,7 +792,7 @@ def _apply_group_offloading_leaf_level(
         record_stream=False,
         low_cpu_mem_usage=low_cpu_mem_usage,
         onload_self=True,
-        _group_id=_GROUP_ID_LAZY_LEAF,
+        group_id=GROUP_ID_LAZY_LEAF,
     )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None)
 
@@ -908,90 +906,3 @@ def _compute_group_hash(group_id):
     hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
     # first 16 characters for a reasonably short but unique name
     return hashed_id[:16]
-
-
-def _get_expected_safetensors_files(
-    module: torch.nn.Module,
-    offload_to_disk_path: str,
-    offload_type: str,
-    num_blocks_per_group: Optional[int] = None,
-) -> Set[str]:
-    expected_files = set()
-
-    def get_hashed_filename(group_id: str) -> str:
-        short_hash = _compute_group_hash(group_id)
-        return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
-
-    if offload_type == "block_level":
-        if num_blocks_per_group is None:
-            raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
-
-        # Handle groups of ModuleList and Sequential blocks
-        for name, submodule in module.named_children():
-            if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
-                continue
-
-            for i in range(0, len(submodule), num_blocks_per_group):
-                current_modules = submodule[i : i + num_blocks_per_group]
-                if not current_modules:
-                    continue
-                start_idx = i
-                end_idx = i + len(current_modules) - 1
-                group_id = f"{name}.{start_idx}_to_{end_idx}"
-                expected_files.add(get_hashed_filename(group_id))
-
-        # Handle the group for unmatched top-level modules and parameters
-        group_id = _GROUP_ID_UNMATCHED_GROUP
-        expected_files.add(get_hashed_filename(group_id))
-
-    elif offload_type == "leaf_level":
-        # Handle leaf-level module groups
-        for name, submodule in module.named_modules():
-            if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
-                # These groups will always have parameters, so a file is expected
-                expected_files.add(get_hashed_filename(name))
-
-        # Handle groups for non-leaf parameters/buffers
-        modules_with_group_offloading = {
-            name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
-        }
-        parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
-        buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
-
-        all_orphans = parameters + buffers
-        if all_orphans:
-            parent_to_tensors = {}
-            module_dict = dict(module.named_modules())
-            for tensor_name, _ in all_orphans:
-                parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
-                if parent_name not in parent_to_tensors:
-                    parent_to_tensors[parent_name] = []
-                parent_to_tensors[parent_name].append(tensor_name)
-
-            for parent_name in parent_to_tensors:
-                # A file is expected for each parent that gathers orphaned tensors
-                expected_files.add(get_hashed_filename(parent_name))
-        expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))
-
-    else:
-        raise ValueError(f"Unsupported offload_type: {offload_type}")
-
-    return expected_files
-
-
-def _check_safetensors_serialization(
-    module: torch.nn.Module,
-    offload_to_disk_path: str,
-    offload_type: str,
-    num_blocks_per_group: Optional[int] = None,
-) -> bool:
-    if not os.path.isdir(offload_to_disk_path):
-        return False, None, None
-
-    expected_files = _get_expected_safetensors_files(module, offload_to_disk_path, offload_type, num_blocks_per_group)
-    actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
-    missing_files = expected_files - actual_files
-    extra_files = actual_files - expected_files
-
-    is_correct = not missing_files and not extra_files
-    return is_correct, extra_files, missing_files
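The diff keeps _compute_group_hash, so the mapping from a group_id to its on-disk safetensors file is unchanged. For reference, a minimal sketch of that mapping, mirroring _compute_group_hash and the removed get_hashed_filename helper; the function name group_offload_filename below is hypothetical, not part of the library:

import hashlib
import os


def group_offload_filename(offload_to_disk_path: str, group_id: str) -> str:
    # Same scheme as _compute_group_hash: first 16 hex characters of the
    # SHA-256 of the group id, stored as "group_<hash>.safetensors".
    short_hash = hashlib.sha256(group_id.encode("utf-8")).hexdigest()[:16]
    return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")


# e.g. the lazy-leaf group created above ("lazy_leafs" per GROUP_ID_LAZY_LEAF):
print(group_offload_filename("/tmp/offload", "lazy_leafs"))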