@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple
 
 import torch
@@ -56,23 +56,58 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+        low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
         self.onload_device = onload_device
         self.offload_leader = offload_leader
         self.onload_leader = onload_leader
-        self.parameters = parameters
-        self.buffers = buffers
+        self.parameters = parameters or []
+        self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
+        self.low_cpu_mem_usage = low_cpu_mem_usage
 
-        if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+        self.cpu_param_dict = self._init_cpu_param_dict()
+
+    def _init_cpu_param_dict(self):
+        cpu_param_dict = {}
+        if self.stream is None:
+            return cpu_param_dict
+
+        for module in self.modules:
+            for param in module.parameters():
+                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+            for buffer in module.buffers():
+                cpu_param_dict[buffer] = (
+                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+                )
+
+        for param in self.parameters:
+            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+
+        for buffer in self.buffers:
+            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+
+        return cpu_param_dict
+
+    @contextmanager
+    def _pinned_memory_tensors(self):
+        pinned_dict = {}
+        try:
+            for param, tensor in self.cpu_param_dict.items():
+                if not tensor.is_pinned():
+                    pinned_dict[param] = tensor.pin_memory()
+                else:
+                    pinned_dict[param] = tensor
+
+            yield pinned_dict
+
+        finally:
+            pinned_dict = None
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
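
The `_pinned_memory_tensors` context manager is the heart of the change: with `low_cpu_mem_usage=True`, the persistent CPU copies stay pageable, and page-locked staging buffers are created only for the duration of an onload, then dropped. A minimal standalone sketch of the same pattern, with illustrative names not taken from the diff:

```python
from contextlib import contextmanager

import torch


@contextmanager
def pinned_staging(cpu_tensors):
    # Pin (page-lock) pageable CPU tensors only while a transfer is in flight;
    # tensors that are already pinned are passed through as-is.
    staged = {key: (t if t.is_pinned() else t.pin_memory()) for key, t in cpu_tensors.items()}
    try:
        yield staged
    finally:
        staged = None  # drop references so the temporary pinned copies can be freed


cpu_weights = {"w": torch.randn(1024, 1024)}  # pageable persistent copy
if torch.cuda.is_available():
    with pinned_staging(cpu_weights) as staged:
        gpu_w = staged["w"].to("cuda", non_blocking=True)
```

The trade-off is explicit: pinning once at init (the default path) costs page-locked host RAM for every offloaded tensor but makes each onload cheap, while pinning on demand keeps the resident footprint low at the cost of an extra host-side copy per transfer.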
@@ -82,15 +117,30 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                for buffer in group_module.buffers():
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    for group_module in self.modules:
+                        for param in group_module.parameters():
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                        for buffer in group_module.buffers():
+                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for param in self.parameters:
+                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for buffer in self.buffers:
+                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+            else:
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    for buffer in group_module.buffers():
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+
                 for param in self.parameters:
                     param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
+
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
 
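
In the stream path, every copy is issued from a pinned staging buffer with `non_blocking=True`; that combination is what lets the host-to-device transfer overlap with compute on the default stream. A toy illustration of the overlap mechanism, independent of this file:

```python
import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    src = torch.randn(64, 1024, 1024).pin_memory()  # async H2D copies require a pinned source

    with torch.cuda.stream(side_stream):
        dst = src.to("cuda", non_blocking=True)  # copy is enqueued on the side stream

    # Kernels queued on the default stream can execute while the copy proceeds.
    torch.cuda.current_stream().wait_stream(side_stream)  # order before consuming `dst`
```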
@@ -101,21 +151,18 @@ def offload_(self):
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = self.cpu_param_dict[param]
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = self.cpu_param_dict[buffer]
+            for param in self.parameters:
+                param.data = self.cpu_param_dict[param]
+            for buffer in self.buffers:
+                buffer.data = self.cpu_param_dict[buffer]
+
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for param in self.parameters:
+                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for buffer in self.buffers:
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
 
 class GroupOffloadingHook(ModelHook):
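
Note what offloading does in the stream path: it never copies data back to the CPU. Each `param.data` is simply re-pointed at the persistent CPU copy held in `cpu_param_dict`, which drops the last reference to the device tensor and lets the CUDA allocator reclaim it. A compressed sketch of the idea:

```python
import torch

cpu_master = torch.randn(4, 4)                 # persistent CPU copy, kept for the model's lifetime
param = torch.nn.Parameter(cpu_master.clone())

if torch.cuda.is_available():
    param.data = param.data.to("cuda")         # onload: H2D copy
    param.data = cpu_master                    # offload: re-point, no device-to-host copy needed
```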
@@ -284,6 +331,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -365,10 +413,12 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(
+            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+        )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
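
With the flag threaded through both code paths, callers opt in at the call site. A sketch of the expected usage, assuming the public signature mirrors the internal parameters shown in this diff (`model` stands in for any supported `torch.nn.Module`):

```python
import torch
from diffusers.hooks import apply_group_offloading

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,          # prefetch via a CUDA stream
    low_cpu_mem_usage=True,   # keep CPU copies pageable; pin only during transfers
)
```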
@@ -380,6 +430,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -400,11 +451,6 @@ def _apply_group_offloading_block_level(
             for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
     unmatched_modules = []
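
Block-level grouping, as the docstring above describes, collects consecutive entries of each `torch.nn.ModuleList`/`torch.nn.Sequential` into chunks of `num_blocks_per_group` that onload and offload as a unit. The chunking itself amounts to this (illustrative, not the diff's exact loop):

```python
import torch.nn as nn

blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(5))
num_blocks_per_group = 2

groups = [blocks[i : i + num_blocks_per_group] for i in range(0, len(blocks), num_blocks_per_group)]
# -> group sizes [2, 2, 1]; each chunk becomes one ModuleGroup with shared on/offload leaders
```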
@@ -425,7 +471,7 @@ def _apply_group_offloading_block_level(
                 onload_leader=current_modules[0],
                 non_blocking=non_blocking,
                 stream=stream,
-                cpu_param_dict=cpu_param_dict,
+                low_cpu_mem_usage=low_cpu_mem_usage,
                 onload_self=stream is None,
             )
             matched_module_groups.append(group)
@@ -462,7 +508,6 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -475,6 +520,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -497,11 +543,6 @@ def _apply_group_offloading_leaf_level(
             for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
@@ -515,7 +556,7 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(submodule, group, None)
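
Leaf-level grouping instead wraps each parameter-holding submodule in its own group, minimizing peak VRAM at the price of many small transfers. A simplified stand-in for the selection step (the real code checks submodules against a tuple of supported layer types rather than counting parameters):

```python
import torch.nn as nn

def parameter_leaves(model: nn.Module):
    # Yield modules that directly own parameters (non-recursive);
    # under leaf_level, each of these becomes its own offloading group.
    for name, submodule in model.named_modules():
        if any(True for _ in submodule.parameters(recurse=False)):
            yield name, submodule

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
print([name for name, _ in parameter_leaves(model)])  # ['0', '2']
```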
@@ -560,7 +601,7 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(parent_module, group, None)
@@ -579,7 +620,7 @@ def _apply_group_offloading_leaf_level(
             buffers=None,
             non_blocking=False,
             stream=None,
-            cpu_param_dict=None,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None)
@@ -616,17 +657,6 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 
-def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
-    cpu_param_dict = {}
-    for param in module.parameters():
-        param.data = param.data.cpu().pin_memory()
-        cpu_param_dict[param] = param.data
-    for buffer in module.buffers():
-        buffer.data = buffer.data.cpu().pin_memory()
-        cpu_param_dict[buffer] = buffer.data
-    return cpu_param_dict
-
-
 def _gather_parameters_with_no_group_offloading_parent(
     module: torch.nn.Module, modules_with_group_offloading: Set[str]
 ) -> List[torch.nn.Parameter]: