1212# See the License for the specific language governing permissions and 
1313# limitations under the License. 
1414
15- from  contextlib  import  nullcontext 
15+ from  contextlib  import  contextmanager ,  nullcontext 
1616from  typing  import  Dict , List , Optional , Set , Tuple 
1717
1818import  torch 
@@ -56,23 +56,50 @@ def __init__(
5656        buffers : Optional [List [torch .Tensor ]] =  None ,
5757        non_blocking : bool  =  False ,
5858        stream : Optional [torch .cuda .Stream ] =  None ,
59-         cpu_param_dict :  Optional [ Dict [ torch . nn . Parameter ,  torch . Tensor ]]  =   None ,
59+         low_cpu_mem_usage = False ,
6060        onload_self : bool  =  True ,
6161    ) ->  None :
6262        self .modules  =  modules 
6363        self .offload_device  =  offload_device 
6464        self .onload_device  =  onload_device 
6565        self .offload_leader  =  offload_leader 
6666        self .onload_leader  =  onload_leader 
67-         self .parameters  =  parameters 
68-         self .buffers  =  buffers 
67+         self .parameters  =  parameters   or  [] 
68+         self .buffers  =  buffers   or  [] 
6969        self .non_blocking  =  non_blocking  or  stream  is  not   None 
7070        self .stream  =  stream 
71-         self .cpu_param_dict  =  cpu_param_dict 
7271        self .onload_self  =  onload_self 
72+         self .low_cpu_mem_usage  =  low_cpu_mem_usage 
7373
74-         if  self .stream  is  not   None  and  self .cpu_param_dict  is  None :
75-             raise  ValueError ("cpu_param_dict must be provided when using stream for data transfer." )
74+         self .cpu_param_dict  =  {}
75+         for  module  in  self .modules :
76+             for  param  in  module .parameters ():
77+                 self .cpu_param_dict [param ] =  (
78+                     param .data .cpu () if  self .low_cpu_mem_usage  else  param .data .cpu ().pin_memory ()
79+                 )
80+ 
81+         for  param  in  self .parameters :
82+             self .cpu_param_dict [param ] =  param .data .cpu () if  self .low_cpu_mem_usage  else  param .data .cpu ().pin_memory ()
83+ 
84+         for  buffer  in  self .buffers :
85+             self .cpu_param_dict [buffer ] =  (
86+                 buffer .data .cpu () if  self .low_cpu_mem_usage  else  buffer .data .cpu ().pin_memory ()
87+             )
88+ 
89+     @contextmanager  
90+     def  _pinned_memory_tensors (self ):
91+         pinned_dict  =  {}
92+         try :
93+             for  param , tensor  in  self .cpu_param_dict .items ():
94+                 if  not  tensor .is_pinned ():
95+                     pinned_dict [param ] =  tensor .pin_memory ()
96+                 else :
97+                     pinned_dict [param ] =  tensor 
98+ 
99+             yield  pinned_dict 
100+ 
101+         finally :
102+             pinned_dict  =  None 
76103
77104    def  onload_ (self ):
78105        r"""Onloads the group of modules to the onload_device.""" 
@@ -82,17 +109,32 @@ def onload_(self):
82109            self .stream .synchronize ()
83110
84111        with  context :
85-             for  group_module  in  self .modules :
86-                 for  param  in  group_module .parameters ():
87-                     param .data  =  param .data .to (self .onload_device , non_blocking = self .non_blocking )
88-                 for  buffer  in  group_module .buffers ():
89-                     buffer .data  =  buffer .data .to (self .onload_device , non_blocking = self .non_blocking )
90-             if  self .parameters  is  not   None :
91-                 for  param  in  self .parameters :
92-                     param .data  =  param .data .to (self .onload_device , non_blocking = self .non_blocking )
93-             if  self .buffers  is  not   None :
94-                 for  buffer  in  self .buffers :
95-                     buffer .data  =  buffer .data .to (self .onload_device , non_blocking = self .non_blocking )
112+             if  self .stream  is  not   None :
113+                 with  self ._pinned_memory_tensors () as  pinned_memory :
114+                     for  group_module  in  self .modules :
115+                         for  param  in  group_module .parameters ():
116+                             param .data  =  pinned_memory [param ].to (self .onload_device , non_blocking = self .non_blocking )
117+ 
118+                     if  self .parameters  is  not   None :
119+                         for  param  in  self .parameters :
120+                             param .data  =  pinned_memory [param ].to (self .onload_device , non_blocking = self .non_blocking )
121+ 
122+                     if  self .buffers  is  not   None :
123+                         for  buffer  in  self .buffers :
124+                             buffer .data  =  pinned_memory [buffer ].to (self .onload_device , non_blocking = self .non_blocking )
125+ 
126+             else :
127+                 for  group_module  in  self .modules :
128+                     for  param  in  group_module .parameters ():
129+                         param .data  =  param .data .to (self .onload_device , non_blocking = self .non_blocking )
130+ 
131+                 if  self .parameters  is  not   None :
132+                     for  param  in  self .parameters :
133+                         param .data  =  param .data .to (self .onload_device , non_blocking = self .non_blocking )
134+ 
135+                 if  self .buffers  is  not   None :
136+                     for  buffer  in  self .buffers :
137+                         buffer .data  =  buffer .data .to (self .onload_device , non_blocking = self .non_blocking )
96138
97139    def  offload_ (self ):
98140        r"""Offloads the group of modules to the offload_device.""" 
@@ -108,12 +150,12 @@ def offload_(self):
108150                for  buffer  in  self .buffers :
109151                    buffer .data  =  self .cpu_param_dict [buffer ]
110152        else :
111-             for  group_module  in  self .modules :
112-                 group_module .to (self .offload_device , non_blocking = self .non_blocking )
113-             if  self .parameters   is   not   None :
153+             for  module  in  self .modules :
154+                 module .to (self .offload_device , non_blocking = self .non_blocking )
155+             if  self .parameters :
114156                for  param  in  self .parameters :
115157                    param .data  =  param .data .to (self .offload_device , non_blocking = self .non_blocking )
116-             if  self .buffers   is   not   None :
158+             if  self .buffers :
117159                for  buffer  in  self .buffers :
118160                    buffer .data  =  buffer .data .to (self .offload_device , non_blocking = self .non_blocking )
119161
@@ -284,6 +326,7 @@ def apply_group_offloading(
284326    num_blocks_per_group : Optional [int ] =  None ,
285327    non_blocking : bool  =  False ,
286328    use_stream : bool  =  False ,
329+     low_cpu_mem_usage = False ,
287330) ->  None :
288331    r""" 
289332    Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and 
@@ -365,10 +408,12 @@ def apply_group_offloading(
365408            raise  ValueError ("num_blocks_per_group must be provided when using offload_type='block_level'." )
366409
367410        _apply_group_offloading_block_level (
368-             module , num_blocks_per_group , offload_device , onload_device , non_blocking , stream 
411+             module , num_blocks_per_group , offload_device , onload_device , non_blocking , stream ,  low_cpu_mem_usage 
369412        )
370413    elif  offload_type  ==  "leaf_level" :
371-         _apply_group_offloading_leaf_level (module , offload_device , onload_device , non_blocking , stream )
414+         _apply_group_offloading_leaf_level (
415+             module , offload_device , onload_device , non_blocking , stream , low_cpu_mem_usage 
416+         )
372417    else :
373418        raise  ValueError (f"Unsupported offload_type: { offload_type }  " )
374419
@@ -380,6 +425,7 @@ def _apply_group_offloading_block_level(
380425    onload_device : torch .device ,
381426    non_blocking : bool ,
382427    stream : Optional [torch .cuda .Stream ] =  None ,
428+     low_cpu_mem_usage : bool  =  False ,
383429) ->  None :
384430    r""" 
385431    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to 
@@ -400,11 +446,6 @@ def _apply_group_offloading_block_level(
400446            for overlapping computation and data transfer. 
401447    """ 
402448
403-     # Create a pinned CPU parameter dict for async data transfer if streams are to be used 
404-     cpu_param_dict  =  None 
405-     if  stream  is  not   None :
406-         cpu_param_dict  =  _get_pinned_cpu_param_dict (module )
407- 
408449    # Create module groups for ModuleList and Sequential blocks 
409450    modules_with_group_offloading  =  set ()
410451    unmatched_modules  =  []
@@ -425,7 +466,7 @@ def _apply_group_offloading_block_level(
425466                onload_leader = current_modules [0 ],
426467                non_blocking = non_blocking ,
427468                stream = stream ,
428-                 cpu_param_dict = cpu_param_dict ,
469+                 low_cpu_mem_usage = low_cpu_mem_usage ,
429470                onload_self = stream  is  None ,
430471            )
431472            matched_module_groups .append (group )
@@ -462,7 +503,6 @@ def _apply_group_offloading_block_level(
462503        buffers = buffers ,
463504        non_blocking = False ,
464505        stream = None ,
465-         cpu_param_dict = None ,
466506        onload_self = True ,
467507    )
468508    next_group  =  matched_module_groups [0 ] if  len (matched_module_groups ) >  0  else  None 
@@ -475,6 +515,7 @@ def _apply_group_offloading_leaf_level(
475515    onload_device : torch .device ,
476516    non_blocking : bool ,
477517    stream : Optional [torch .cuda .Stream ] =  None ,
518+     low_cpu_mem_usage : bool  =  False ,
478519) ->  None :
479520    r""" 
480521    This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory 
@@ -497,11 +538,6 @@ def _apply_group_offloading_leaf_level(
497538            for overlapping computation and data transfer. 
498539    """ 
499540
500-     # Create a pinned CPU parameter dict for async data transfer if streams are to be used 
501-     cpu_param_dict  =  None 
502-     if  stream  is  not   None :
503-         cpu_param_dict  =  _get_pinned_cpu_param_dict (module )
504- 
505541    # Create module groups for leaf modules and apply group offloading hooks 
506542    modules_with_group_offloading  =  set ()
507543    for  name , submodule  in  module .named_modules ():
@@ -515,7 +551,7 @@ def _apply_group_offloading_leaf_level(
515551            onload_leader = submodule ,
516552            non_blocking = non_blocking ,
517553            stream = stream ,
518-             cpu_param_dict = cpu_param_dict ,
554+             low_cpu_mem_usage = low_cpu_mem_usage ,
519555            onload_self = True ,
520556        )
521557        _apply_group_offloading_hook (submodule , group , None )
@@ -560,7 +596,7 @@ def _apply_group_offloading_leaf_level(
560596            buffers = buffers ,
561597            non_blocking = non_blocking ,
562598            stream = stream ,
563-             cpu_param_dict = cpu_param_dict ,
599+             low_cpu_mem_usage = low_cpu_mem_usage ,
564600            onload_self = True ,
565601        )
566602        _apply_group_offloading_hook (parent_module , group , None )
@@ -579,7 +615,7 @@ def _apply_group_offloading_leaf_level(
579615            buffers = None ,
580616            non_blocking = False ,
581617            stream = None ,
582-             cpu_param_dict = None ,
618+             low_cpu_mem_usage = low_cpu_mem_usage ,
583619            onload_self = True ,
584620        )
585621        _apply_lazy_group_offloading_hook (module , unmatched_group , None )
@@ -616,17 +652,6 @@ def _apply_lazy_group_offloading_hook(
616652    registry .register_hook (lazy_prefetch_hook , _LAZY_PREFETCH_GROUP_OFFLOADING )
617653
618654
619- def  _get_pinned_cpu_param_dict (module : torch .nn .Module ) ->  Dict [torch .nn .Parameter , torch .Tensor ]:
620-     cpu_param_dict  =  {}
621-     for  param  in  module .parameters ():
622-         param .data  =  param .data .cpu ().pin_memory ()
623-         cpu_param_dict [param ] =  param .data 
624-     for  buffer  in  module .buffers ():
625-         buffer .data  =  buffer .data .cpu ().pin_memory ()
626-         cpu_param_dict [buffer ] =  buffer .data 
627-     return  cpu_param_dict 
628- 
629- 
630655def  _gather_parameters_with_no_group_offloading_parent (
631656    module : torch .nn .Module , modules_with_group_offloading : Set [str ]
632657) ->  List [torch .nn .Parameter ]:
0 commit comments