@@ -73,7 +73,18 @@ def __init__(
7373
7474 self .cpu_param_dict = {}
7575 for module in self .modules :
76- self .cpu_param_dict .update (_get_cpu_param_dict (module , self .low_cpu_mem_usage ))
76+ for param in module .parameters ():
77+ self .cpu_param_dict [param ] = (
78+ param .data .cpu () if self .low_cpu_mem_usage else param .data .cpu ().pin_memory ()
79+ )
80+
81+ for param in self .parameters :
82+ self .cpu_param_dict [param ] = param .data .cpu () if self .low_cpu_mem_usage else param .data .cpu ().pin_memory ()
83+
84+ for buffer in self .buffers :
85+ self .cpu_param_dict [buffer ] = (
86+ buffer .data .cpu () if self .low_cpu_mem_usage else buffer .data .cpu ().pin_memory ()
87+ )
7788
7889 @contextmanager
7990 def _pinned_memory_tensors (self ):
@@ -100,20 +111,30 @@ def onload_(self):
100111 with context :
101112 if self .stream is not None :
102113 with self ._pinned_memory_tensors () as pinned_memory :
103- for module in self .modules :
104- for param in module .parameters ():
114+ for group_module in self .modules :
115+ for param in group_module .parameters ():
105116 param .data = pinned_memory [param ].to (self .onload_device , non_blocking = self .non_blocking )
117+
118+ if self .parameters is not None :
119+ for param in self .parameters :
120+ param .data = pinned_memory [param ].to (self .onload_device , non_blocking = self .non_blocking )
121+
122+ if self .buffers is not None :
123+ for buffer in self .buffers :
124+ buffer .data = pinned_memory [buffer ].to (self .onload_device , non_blocking = self .non_blocking )
125+
106126 else :
107127 for group_module in self .modules :
108128 for param in group_module .parameters ():
109129 param .data = param .data .to (self .onload_device , non_blocking = self .non_blocking )
110130
111- if self .parameters is not None :
112- for param in self .parameters :
113- param .data = param .data .to (self .onload_device , non_blocking = self .non_blocking )
114- if self .buffers is not None :
115- for buffer in self .buffers :
116- buffer .data = buffer .data .to (self .onload_device , non_blocking = self .non_blocking )
131+ if self .parameters is not None :
132+ for param in self .parameters :
133+ param .data = param .data .to (self .onload_device , non_blocking = self .non_blocking )
134+
135+ if self .buffers is not None :
136+ for buffer in self .buffers :
137+ buffer .data = buffer .data .to (self .onload_device , non_blocking = self .non_blocking )
117138
118139 def offload_ (self ):
119140 r"""Offloads the group of modules to the offload_device."""
@@ -631,7 +652,7 @@ def _apply_lazy_group_offloading_hook(
631652 registry .register_hook (lazy_prefetch_hook , _LAZY_PREFETCH_GROUP_OFFLOADING )
632653
633654
634- def _get_cpu_param_dict (
655+ def _assign_cpu_param_dict (
635656 module : torch .nn .Module , low_cpu_mem_usage : bool = False
636657) -> Dict [torch .nn .Parameter , torch .Tensor ]:
637658 cpu_param_dict = {}
0 commit comments