1212# See the License for the specific language governing permissions and 
1313# limitations under the License. 
1414
15+ import  hashlib 
1516import  os 
1617from  contextlib  import  contextmanager , nullcontext 
1718from  typing  import  Dict , List , Optional , Set , Tuple , Union 
@@ -62,6 +63,7 @@ def __init__(
6263        low_cpu_mem_usage : bool  =  False ,
6364        onload_self : bool  =  True ,
6465        offload_to_disk_path : Optional [str ] =  None ,
66+         _group_id : Optional [int ] =  None ,
6567    ) ->  None :
6668        self .modules  =  modules 
6769        self .offload_device  =  offload_device 
@@ -80,7 +82,9 @@ def __init__(
8082        self ._is_offloaded_to_disk  =  False 
8183
8284        if  self .offload_to_disk_path :
83-             self .safetensors_file_path  =  os .path .join (self .offload_to_disk_path , f"group_{ id (self )}  )
85+             self ._group_id  =  _group_id 
86+             short_hash  =  self ._compute_group_hash (self ._group_id )
87+             self .safetensors_file_path  =  os .path .join (self .offload_to_disk_path , f"group_{ short_hash }  )
8488
8589            all_tensors  =  []
8690            for  module  in  self .modules :
@@ -260,6 +264,11 @@ def offload_(self):
260264            for  buffer  in  self .buffers :
261265                buffer .data  =  buffer .data .to (self .offload_device , non_blocking = self .non_blocking )
262266
267+     def  _compute_group_hash (self , group_id ):
268+         hashed_id  =  hashlib .sha256 (group_id .encode ("utf-8" )).hexdigest ()
269+         # first 16 characters for a reasonably short but unique name 
270+         return  hashed_id [:16 ]
271+ 
263272
264273class  GroupOffloadingHook (ModelHook ):
265274    r""" 
@@ -603,6 +612,9 @@ def _apply_group_offloading_block_level(
603612
604613        for  i  in  range (0 , len (submodule ), num_blocks_per_group ):
605614            current_modules  =  submodule [i  : i  +  num_blocks_per_group ]
615+             start_idx  =  i 
616+             end_idx  =  i  +  len (current_modules ) -  1 
617+             group_id  =  f"{ name } { start_idx } { end_idx }  
606618            group  =  ModuleGroup (
607619                modules = current_modules ,
608620                offload_device = offload_device ,
@@ -615,6 +627,7 @@ def _apply_group_offloading_block_level(
615627                record_stream = record_stream ,
616628                low_cpu_mem_usage = low_cpu_mem_usage ,
617629                onload_self = True ,
630+                 _group_id = group_id ,
618631            )
619632            matched_module_groups .append (group )
620633            for  j  in  range (i , i  +  len (current_modules )):
@@ -649,6 +662,7 @@ def _apply_group_offloading_block_level(
649662        stream = None ,
650663        record_stream = False ,
651664        onload_self = True ,
665+         _group_id = "top_level_unmatched_modules" ,
652666    )
653667    if  stream  is  None :
654668        _apply_group_offloading_hook (module , unmatched_group , None )
@@ -715,6 +729,7 @@ def _apply_group_offloading_leaf_level(
715729            record_stream = record_stream ,
716730            low_cpu_mem_usage = low_cpu_mem_usage ,
717731            onload_self = True ,
732+             _group_id = name ,
718733        )
719734        _apply_group_offloading_hook (submodule , group , None )
720735        modules_with_group_offloading .add (name )
@@ -762,6 +777,7 @@ def _apply_group_offloading_leaf_level(
762777            record_stream = record_stream ,
763778            low_cpu_mem_usage = low_cpu_mem_usage ,
764779            onload_self = True ,
780+             _group_id = name ,
765781        )
766782        _apply_group_offloading_hook (parent_module , group , None )
767783
@@ -783,6 +799,7 @@ def _apply_group_offloading_leaf_level(
783799            record_stream = False ,
784800            low_cpu_mem_usage = low_cpu_mem_usage ,
785801            onload_self = True ,
802+             name = "lazy_leafs" ,
786803        )
787804        _apply_lazy_group_offloading_hook (module , unmatched_group , None )
788805
0 commit comments