@@ -82,52 +82,85 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            # Only transfer parameters that aren't already on the target device
+            # Use direct per-parameter transfers rather than module-level transfers
+            # This gives us more control and potentially better memory management
             for group_module in self.modules:
+                # Check if any parameter needs moving
                 if any(p.device != self.onload_device for p in group_module.parameters()):
-                    group_module.to(self.onload_device, non_blocking=self.non_blocking)
-
+                    for param in group_module.parameters():
+                        if param.device != self.onload_device:
+                            # Use direct CUDA transfer for each parameter
+                            if self.onload_device.type == "cuda":
+                                param.data = param.data.cuda(self.onload_device.index,
+                                                             non_blocking=self.non_blocking)
+                            else:
+                                param.data = param.data.to(self.onload_device,
+                                                           non_blocking=self.non_blocking)
+
+            # Handle explicit parameters
             if self.parameters is not None:
                 for param in self.parameters:
                     if param.device != self.onload_device:
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-
+                        # Use direct CUDA transfer for each parameter
+                        if self.onload_device.type == "cuda":
+                            param.data = param.data.cuda(self.onload_device.index,
+                                                         non_blocking=self.non_blocking)
+                        else:
+                            param.data = param.data.to(self.onload_device,
+                                                       non_blocking=self.non_blocking)
+
+            # Handle buffers
             if self.buffers is not None:
                 for buffer in self.buffers:
                     if buffer.device != self.onload_device:
-                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                        # Use direct CUDA transfer for each buffer
+                        if self.onload_device.type == "cuda":
+                            buffer.data = buffer.data.cuda(self.onload_device.index,
+                                                           non_blocking=self.non_blocking)
+                        else:
+                            buffer.data = buffer.data.to(self.onload_device,
+                                                         non_blocking=self.non_blocking)
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
-        # Synchronize if using stream
-        if self.stream is not None:
-            torch.cuda.current_stream().synchronize()
-
-        # For CPU offloading, use a method that preserves memory mapping benefits
-        if self.offload_device.type == 'cpu':
+        # For CPU offloading, use the most memory-efficient approach possible
+        if self.offload_device.type == "cpu":
+            # Synchronize if using stream
+            if self.stream is not None:
+                torch.cuda.current_stream().synchronize()
+
             # Empty GPU cache before offloading to reduce memory fragmentation
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
 
-            # Use to() method directly on modules
+            # Instead of using to() on the whole module which might create copies,
+            # directly move each parameter's data to CPU with cpu() which uses
+            # the memory-optimized path
             for group_module in self.modules:
-                # Don't make copies if already on CPU
-                if any(p.device.type != 'cpu' for p in group_module.parameters()):
-                    group_module.to(self.offload_device, non_blocking=self.non_blocking)
-
-            # Handle explicit parameters - avoid copies when already on CPU
+                # Check if any parameter needs moving
+                if any(p.device.type != "cpu" for p in group_module.parameters()):
+                    for param in group_module.parameters():
+                        if param.device.type != "cpu":
+                            # Use direct cpu() method which is more memory-efficient than to()
+                            param.data = param.data.cpu()
+
+            # Handle explicit parameters - move directly to CPU
            if self.parameters is not None:
                 for param in self.parameters:
-                    if param.device.type != 'cpu':
-                        # Let PyTorch handle the transfer which can preserve memory mapping
-                        param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                    if param.device.type != "cpu":
+                        # Direct CPU transfer with cpu() method
+                        param.data = param.data.cpu()
 
-            # Handle buffers - avoid copies when already on CPU
             if self.buffers is not None:
                 for buffer in self.buffers:
-                    if buffer.device.type != 'cpu':
-                        buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+            # Handle buffers - move directly to CPU
+                    if buffer.device.type != "cpu":
+                        buffer.data = buffer.data.cpu()
         else:
+            # For non-CPU offloading, synchronize if using stream
+            if self.stream is not None:
+                torch.cuda.current_stream().synchronize()
+
             # For non-CPU offloading, use the regular approach
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
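
The hunk above replaces module-level `Module.to(...)` calls with per-parameter moves of each tensor's `.data`: `Tensor.cuda(index, non_blocking=...)` when onloading to a CUDA device, and `Tensor.cpu()` when offloading back to host memory. Below is a minimal standalone sketch of that pattern; the module, device, and flag choices are illustrative only and are not part of the patch.

```python
import torch
import torch.nn as nn

# Minimal sketch of the per-parameter transfer pattern used in onload_/offload_.
# The module and device below are illustrative; they are not part of the patch.
module = nn.Linear(16, 16)
onload_device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Onload: move each parameter's storage individually instead of module.to(...)
for param in module.parameters():
    if param.device != onload_device:
        if onload_device.type == "cuda":
            param.data = param.data.cuda(onload_device.index, non_blocking=True)
        else:
            param.data = param.data.to(onload_device, non_blocking=True)

# Offload: move each parameter back to host memory with Tensor.cpu()
for param in module.parameters():
    if param.device.type != "cpu":
        param.data = param.data.cpu()
```

Reassigning `param.data` swaps the underlying storage while keeping the `nn.Parameter` object (and any existing references to it) intact, which is what lets the hook move weights back and forth without rebuilding the module.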
@@ -394,9 +427,7 @@ def apply_group_offloading(
             stream,
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(
-            module, offload_device, onload_device, non_blocking, stream
-        )
+        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
 
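
For orientation, a call into the public entry point might look roughly like the sketch below. The keyword names are inferred only from the parameters visible in this hunk and are not guaranteed to match the actual signature; in particular, how the optional CUDA stream is requested is not shown here.

```python
import torch
from torch import nn

# Hypothetical usage sketch: the keyword names below are inferred from the
# parameters visible in the hunk above and may not match the real signature.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",  # any other value raises ValueError, per the dispatch above
    non_blocking=True,
)
```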