 from concurrent.futures import ThreadPoolExecutor
+import gc
 import time
 from typing import Optional, Union, Callable, Tuple
 import torch
 import torch.nn as nn
 
-from library.device_utils import clean_memory_on_device
 
+# Keep these functions here for portability, and private to avoid confusion with the ones in device_utils.py
+def _clean_memory_on_device(device: torch.device):
+    r"""
+    Clean memory on the specified device, will be called from training scripts.
+    """
+    gc.collect()
+
+    # device may be "cuda" or "cuda:0", so we need to check the type of device
+    if device.type == "cuda":
+        torch.cuda.empty_cache()
+    if device.type == "xpu":
+        torch.xpu.empty_cache()
+    if device.type == "mps":
+        torch.mps.empty_cache()
 
-def synchronize_device(device: torch.device):
+
+def _synchronize_device(device: torch.device):
     if device.type == "cuda":
         torch.cuda.synchronize()
     elif device.type == "xpu":
@@ -71,19 +86,18 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
         if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
             weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
 
-
     # device to cpu
     for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
         module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
 
-    synchronize_device(device)
+    _synchronize_device(device)
 
     # cpu to device
     for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
         cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
         module_to_cuda.weight.data = cuda_data_view
 
-    synchronize_device(device)
+    _synchronize_device(device)
 
 
 def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -152,12 +166,15 @@ def _wait_blocks_move(self, block_idx):
 # Gradient tensors
 _grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]
 
+
 class ModelOffloader(Offloader):
     """
     supports forward offloading
     """
 
-    def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
+    def __init__(
+        self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False
+    ):
         super().__init__(len(blocks), blocks_to_swap, device, debug)
 
         # register backward hooks
@@ -172,7 +189,9 @@ def __del__(self):
         for handle in self.remove_handles:
             handle.remove()
 
-    def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
+    def create_backward_hook(
+        self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int
+    ) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
         # -1 for 0-based index
         num_blocks_propagated = self.num_blocks - block_index - 1
         swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -213,8 +232,8 @@ def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn
             b.to(self.device)  # move block to device first
             weighs_to_device(b, torch.device("cpu"))  # make sure weights are on cpu
 
-        synchronize_device(self.device)
-        clean_memory_on_device(self.device)
+        _synchronize_device(self.device)
+        _clean_memory_on_device(self.device)
 
     def wait_for_block(self, block_idx: int):
         if self.blocks_to_swap is None or self.blocks_to_swap == 0:
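
As a reading aid only, here is a minimal, hypothetical sketch of how the pieces changed in this diff are typically driven during a forward pass: prepare_block_devices_before_forward is called once before the block loop, and wait_for_block before each block runs. The stand-in blocks, tensor sizes, and blocks_to_swap value are assumptions invented for illustration, and the call that schedules a finished block back to CPU is omitted because it is not part of this hunk.

import torch
import torch.nn as nn

# Stand-in transformer blocks and sizes (illustrative assumptions only).
blocks = nn.ModuleList([nn.Linear(1024, 1024) for _ in range(8)])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep 4 of the 8 blocks offloaded to CPU at a time; backward hooks are
# registered inside __init__ as shown in the diff above.
offloader = ModelOffloader(blocks, blocks_to_swap=4, device=device, debug=True)
offloader.prepare_block_devices_before_forward(blocks)

x = torch.randn(1, 1024, device=device)
for idx, block in enumerate(blocks):
    offloader.wait_for_block(idx)  # wait until this block's weights are on the device
    x = block(x)
    # Scheduling the finished block back to CPU is done by a method that is not
    # shown in this hunk, so it is omitted from this sketch.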