@@ -95,7 +95,7 @@ def __init__(
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False
 
-        if self.offload_to_disk_path:
+        if self.offload_to_disk_path is not None:
             # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
             self.group_id = group_id if group_id is not None else str(id(self))
             short_hash = _compute_group_hash(self.group_id)
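
The `is not None` check matters because a falsy-but-valid value would otherwise be skipped, mirroring the `group_id` comment just below. A minimal sketch of the difference (illustrative, not part of the diff):

```python
path = ""  # falsy, yet a value the caller passed on purpose

if path:
    print("truthiness check")  # skipped: empty string is falsy

if path is not None:
    print("None check")  # runs: only None is excluded
```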
@@ -115,6 +115,12 @@ def __init__(
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()
 
+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
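
Caching the accelerator module once in `__init__` saves re-resolving it on every transfer. A sketch of what the cached expression evaluates to (assuming a CUDA machine with an accelerator present; XPU or MPS builds report those types instead):

```python
import torch

# Newer PyTorch exposes torch.accelerator; older builds fall back to torch.cuda.
if hasattr(torch, "accelerator"):
    accelerator_type = torch.accelerator.current_accelerator().type  # e.g. "cuda"
    accelerator_module = getattr(torch, accelerator_type)            # e.g. torch.cuda
else:
    accelerator_module = torch.cuda

stream = accelerator_module.current_stream()  # same object record_stream() receives
```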
@@ -138,112 +144,76 @@ def _init_cpu_param_dict(self):
 
     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None
 
-    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream and current_stream is not None:
-            tensor.data.record_stream(current_stream)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())
 
-    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+    def _process_tensors_from_modules(self, pinned_memory=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source, current_stream)
+                self._transfer_tensor_to_device(param, source)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source, current_stream)
+                self._transfer_tensor_to_device(buffer, source)
 
         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source, current_stream)
+            self._transfer_tensor_to_device(param, source)
 
         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source, current_stream)
+            self._transfer_tensor_to_device(buffer, source)
 
-    def _onload_from_disk(self, current_stream):
+    def _onload_from_disk(self):
         if self.stream is not None:
-            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-
-            for key, tensor_obj in self.key_to_tensor.items():
-                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
-
-            with self._pinned_memory_tensors() as pinned_memory:
-                for key, tensor_obj in self.key_to_tensor.items():
-                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
-
-            self.cpu_param_dict.clear()
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
 
-        else:
-            onload_device = (
-                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-            )
-            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-            for key, tensor_obj in self.key_to_tensor.items():
-                tensor_obj.data = loaded_tensors[key]
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
 
-    def _onload_from_memory(self, current_stream):
-        if self.stream is not None:
-            with self._pinned_memory_tensors() as pinned_memory:
-                self._process_tensors_from_modules(pinned_memory, current_stream)
-        else:
-            self._process_tensors_from_modules(None, current_stream)
-
-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+        with context:
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = str(self.onload_device) if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
 
-        if self.offload_to_disk_path:
             if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
+            else:
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]
 
+    def _onload_from_memory(self):
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
         with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory)
             else:
-                self._onload_from_memory(current_stream)
+                self._process_tensors_from_modules(None)
 
     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
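
The refactor keeps the same transfer pattern but localizes stream handling in each `_onload_from_*` helper instead of threading `current_stream` through every call. The underlying pinned-memory/side-stream idiom can be sketched as follows (a rough illustration with made-up tensor names, assuming CUDA; it is not the hook code itself):

```python
import torch

if torch.cuda.is_available():
    compute_stream = torch.cuda.current_stream()
    transfer_stream = torch.cuda.Stream()
    cpu_weight = torch.randn(4096, 4096).pin_memory()  # pinning enables truly async H2D copies

    with torch.cuda.stream(transfer_stream):
        gpu_weight = cpu_weight.to("cuda", non_blocking=True)  # copy runs on the transfer stream
        # Tell the caching allocator the compute stream will also use this memory,
        # so it is not reused before that stream's pending work finishes.
        gpu_weight.record_stream(compute_stream)

    compute_stream.wait_stream(transfer_stream)  # order compute after the copy
    out = gpu_weight.sum()
```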
@@ -264,14 +234,10 @@ def _offload_to_disk(self):
             tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
 
     def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
+
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -282,15 +248,23 @@ def _offload_to_memory(self):
 
         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()
 
     @torch.compiler.disable()
     def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        r"""Offloads the group of parameters to the offload_device."""
         if self.offload_to_disk_path:
             self._offload_to_disk()
         else:
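
Dropping to `non_blocking=False` on the no-stream path is the safe choice: a non-blocking device-to-host copy can return before the destination is valid, and without a dedicated stream there is no later synchronization point. A rough illustration of the hazard being avoided (assumed behavior, not code from this diff):

```python
import torch

if torch.cuda.is_available():
    gpu = torch.randn(1 << 20, device="cuda")
    pinned = torch.empty_like(gpu, device="cpu").pin_memory()

    pinned.copy_(gpu, non_blocking=True)  # returns immediately; may still be in flight
    torch.cuda.synchronize()              # required before `pinned` is safe to read

    safe = gpu.to("cpu", non_blocking=False)  # blocking copy needs no explicit sync
    assert torch.equal(safe, pinned)
```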
@@ -307,11 +281,9 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(
-        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
-    ) -> None:
+    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
         self.group = group
-        self.next_group = next_group
+        self.next_group: Optional[ModuleGroup] = None
         self.config = config
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -459,8 +431,8 @@ def pre_forward(self, module, *args, **kwargs):
 
 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
     offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
@@ -546,6 +518,8 @@ def apply_group_offloading(
     ```
     """
 
+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
     offload_type = GroupOffloadingType(offload_type)
 
     stream = None
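
With the widened annotations, callers can pass plain strings and have them normalized to `torch.device` up front. A hedged usage sketch (the model and parameter values are illustrative; see the docstring above for the full set of options):

```python
import torch
from diffusers.hooks import apply_group_offloading

# Any torch.nn.Module works; a tiny stack stands in for a real transformer here.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

apply_group_offloading(
    model,
    onload_device="cuda",        # strings are now accepted...
    offload_device="cpu",        # ...and converted to torch.device internally
    offload_type="block_level",
    num_blocks_per_group=1,
)
```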
@@ -633,7 +607,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, None, config=config)
+            _apply_group_offloading_hook(group_module, group, config=config)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -662,9 +636,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
         group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_group_offloading_hook(module, unmatched_group, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -693,7 +667,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(submodule, group, None, config=config)
+        _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)
 
     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -740,7 +714,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(parent_module, group, None, config=config)
+        _apply_group_offloading_hook(parent_module, group, config=config)
 
     if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -762,13 +736,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
             onload_self=True,
             group_id=_GROUP_ID_LAZY_LEAF,
         )
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -777,14 +750,13 @@ def _apply_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
 
 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -793,7 +765,7 @@ def _apply_lazy_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
    if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
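
With `next_group` removed from the constructors, prefetch links have to be wired after the hooks exist, once the lazy hook has observed execution order. A rough sketch of how that linking could look (hypothetical helper, not the actual `LazyPrefetchGroupOffloadingHook` implementation):

```python
from typing import List

def link_prefetch_targets(hooks_in_order: List["GroupOffloadingHook"]) -> None:
    # `hooks_in_order` stands in for whatever order the first forward pass records;
    # every hook starts with next_group = None after this refactor.
    for current_hook, next_hook in zip(hooks_in_order, hooks_in_order[1:]):
        # Point each hook at the group that runs next, so its post-forward can
        # begin onloading the upcoming group while the current one computes.
        current_hook.next_group = next_hook.group
        current_hook.next_group.onload_self = False  # the previous hook now onloads it
```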