 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple
 
 import torch
@@ -56,23 +56,58 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+        low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
         self.onload_device = onload_device
         self.offload_leader = offload_leader
         self.onload_leader = onload_leader
-        self.parameters = parameters
-        self.buffers = buffers
+        self.parameters = parameters or []
+        self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
+        self.low_cpu_mem_usage = low_cpu_mem_usage
 
-        if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+        self.cpu_param_dict = self._init_cpu_param_dict()
+
+    def _init_cpu_param_dict(self):
+        cpu_param_dict = {}
+        if self.stream is None:
+            return cpu_param_dict
+
+        for module in self.modules:
+            for param in module.parameters():
+                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+            for buffer in module.buffers():
+                cpu_param_dict[buffer] = (
+                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+                )
+
+        for param in self.parameters:
+            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+
+        for buffer in self.buffers:
+            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+
+        return cpu_param_dict
+
+    @contextmanager
+    def _pinned_memory_tensors(self):
+        pinned_dict = {}
+        try:
+            for param, tensor in self.cpu_param_dict.items():
+                if not tensor.is_pinned():
+                    pinned_dict[param] = tensor.pin_memory()
+                else:
+                    pinned_dict[param] = tensor
+
+            yield pinned_dict
+
+        finally:
+            pinned_dict = None
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
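
Note on the hunk above: `_init_cpu_param_dict` keeps a CPU-side copy of every parameter and buffer whenever a stream is used, and `low_cpu_mem_usage` decides whether those copies are pinned up front (faster transfers, but a persistent pinned allocation) or left pageable and pinned transiently via `_pinned_memory_tensors`. A minimal standalone sketch of that trade-off (tensor shape and names are illustrative, not from this file):

```python
import torch

weight = torch.randn(1024, 1024)  # CPU copy of a parameter

if torch.cuda.is_available():
    # Default path: pin once up front; the pinned allocation persists
    # for the lifetime of the group, enabling fast async H2D copies.
    persistent_pinned = weight.pin_memory()

    # low_cpu_mem_usage path: pin transiently, only around the transfer.
    staging = weight.pin_memory()                   # temporary pinned copy
    on_gpu = staging.to("cuda", non_blocking=True)  # async host-to-device copy
    torch.cuda.synchronize()                        # copy done; staging can be dropped
    del staging                                     # pinned host memory freed here
```
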
@@ -82,15 +117,30 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                for buffer in group_module.buffers():
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    for group_module in self.modules:
+                        for param in group_module.parameters():
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                        for buffer in group_module.buffers():
+                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for param in self.parameters:
+                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for buffer in self.buffers:
+                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+            else:
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    for buffer in group_module.buffers():
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+
                 for param in self.parameters:
                     param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
+
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
 
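
Why the stream branch reads from `pinned_memory[...]` instead of `param.data`: a `non_blocking=True` copy only overlaps with compute when the source tensor lives in pinned host memory. A standalone sketch of the overlap pattern (stream and tensor names are illustrative):

```python
import torch

if torch.cuda.is_available():
    transfer_stream = torch.cuda.Stream()
    host = torch.randn(1024, 1024).pin_memory()  # pinned source is required for true async

    with torch.cuda.stream(transfer_stream):
        # Enqueued on the side stream; returns immediately, so the copy
        # overlaps with whatever the default stream is computing.
        dev = host.to("cuda", non_blocking=True)

    # Order the streams before the default stream consumes `dev`.
    torch.cuda.current_stream().wait_stream(transfer_stream)
```
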
@@ -101,21 +151,18 @@ def offload_(self):
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = self.cpu_param_dict[param]
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = self.cpu_param_dict[buffer]
+            for param in self.parameters:
+                param.data = self.cpu_param_dict[param]
+            for buffer in self.buffers:
+                buffer.data = self.cpu_param_dict[buffer]
+
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for param in self.parameters:
+                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for buffer in self.buffers:
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
 
 class GroupOffloadingHook(ModelHook):
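
In the stream branch of `offload_` above, nothing is copied device-to-host: `param.data` is simply re-pointed at the persistent CPU copy held in `cpu_param_dict`, which is safe because inference never mutates the on-device weights. A minimal sketch of that pointer-swap pattern (module and dict names are illustrative):

```python
import torch
import torch.nn as nn

if torch.cuda.is_available():
    linear = nn.Linear(4, 4).cuda()
    cpu_copies = {p: p.data.cpu() for p in linear.parameters()}  # persistent CPU copies

    # "Offload": re-point the storage; no D2H transfer is issued, and the
    # GPU memory becomes reclaimable once nothing else references it.
    for p in linear.parameters():
        p.data = cpu_copies[p]
```
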
@@ -284,6 +331,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -365,10 +413,12 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(
+            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+        )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
 
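
End-to-end, the new flag is threaded from the public entry point down to each `ModuleGroup`. A usage sketch against the API as it appears in this diff (the toy `nn.Sequential` stands in for a real diffusers model; `use_stream=True` requires a CUDA device):

```python
import torch
import torch.nn as nn
from diffusers.hooks import apply_group_offloading

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))  # stand-in model

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    low_cpu_mem_usage=True,  # pin host copies transiently instead of persistently
)
```
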
@@ -380,6 +430,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -400,11 +451,6 @@ def _apply_group_offloading_block_level(
             for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
     unmatched_modules = []
@@ -425,7 +471,7 @@ def _apply_group_offloading_block_level(
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=stream is None,
         )
         matched_module_groups.append(group)
@@ -462,7 +508,6 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -475,6 +520,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -497,11 +543,6 @@ def _apply_group_offloading_leaf_level(
             for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
@@ -515,7 +556,7 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(submodule, group, None)
@@ -560,7 +601,7 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(parent_module, group, None)
@@ -579,7 +620,7 @@ def _apply_group_offloading_leaf_level(
             buffers=None,
             non_blocking=False,
             stream=None,
-            cpu_param_dict=None,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None)
@@ -616,17 +657,6 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 
-def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
-    cpu_param_dict = {}
-    for param in module.parameters():
-        param.data = param.data.cpu().pin_memory()
-        cpu_param_dict[param] = param.data
-    for buffer in module.buffers():
-        buffer.data = buffer.data.cpu().pin_memory()
-        cpu_param_dict[buffer] = buffer.data
-    return cpu_param_dict
-
-
 def _gather_parameters_with_no_group_offloading_parent(
     module: torch.nn.Module, modules_with_group_offloading: Set[str]
 ) -> List[torch.nn.Parameter]: