@@ -32,7 +32,7 @@ def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[s
3232 queue = [model ]
3333 while queue :
3434 m = queue .pop (0 )
35- for name , child in m .named_children ():
35+ for _name , child in m .named_children ():
3636 if isinstance (child , nn .ModuleList ) and len (child ) > 0 :
3737 first_type = type (child [0 ]).__name__
3838 if "DecoderLayer" in first_type or "TransformerBlock" in first_type :
@@ -70,7 +70,9 @@ def __init__(
7070 # Find decoder layers
7171 self .layers , layer_types = _find_decoder_layers (model )
7272 if self .layers is None :
73- LOG .warning ("LayerOffloadManager: no decoder layers found, offloading disabled" )
73+ LOG .warning (
74+ "LayerOffloadManager: no decoder layers found, offloading disabled"
75+ )
7476 self .enabled = False
7577 return
7678
@@ -103,7 +105,9 @@ def __init__(
103105
104106 # CPU storage: pinned tensors for each layer's frozen params
105107 # Populated on first offload
106- self ._cpu_data : list [dict [str , torch .Tensor ]] = [{} for _ in range (self .n_layers )]
108+ self ._cpu_data : list [dict [str , torch .Tensor ]] = [
109+ {} for _ in range (self .n_layers )
110+ ]
107111
108112 # Offload all layers upfront
109113 self ._offload_all ()
@@ -146,9 +150,13 @@ def _load_layer(self, idx: int, stream=None):
146150 """Move frozen params of layer idx back to GPU."""
147151 if idx in self ._on_gpu or idx < 0 or idx >= self .n_layers :
148152 return
149- ctx = torch .cuda .stream (stream ) if stream is not None else contextlib .nullcontext ()
153+ ctx = (
154+ torch .cuda .stream (stream )
155+ if stream is not None
156+ else contextlib .nullcontext ()
157+ )
150158 with ctx :
151- for name , param in self ._frozen_params [idx ]:
159+ for _name , param in self ._frozen_params [idx ]:
152160 if param .device .type == "cuda" :
153161 continue
154162 gpu_data = param .data .to (self ._device , non_blocking = True )
@@ -183,6 +191,7 @@ def hook(module, args):
183191 # Prefetch next layer(s)
184192 for offset in range (1 , self .num_prefetch + 1 ):
185193 self ._prefetch_layer (i + offset )
194+
186195 return hook
187196
188197 def make_post_fwd (i ):
@@ -193,6 +202,7 @@ def hook(module, args, output):
193202 # Offload last layer after forward
194203 if i == self .n_layers - 1 :
195204 self ._offload_layer (i )
205+
196206 return hook
197207
198208 def make_pre_bwd (i ):
@@ -204,6 +214,7 @@ def hook(module, grad_output):
204214 # Prefetch previous layer(s)
205215 for offset in range (1 , self .num_prefetch + 1 ):
206216 self ._prefetch_layer (i - offset )
217+
207218 return hook
208219
209220 def make_post_bwd (i ):
@@ -214,6 +225,7 @@ def hook(module, grad_input, grad_output):
214225 # Offload first layer after backward
215226 if i == 0 :
216227 self ._offload_layer (i )
228+
217229 return hook
218230
219231 h1 = layer .register_forward_pre_hook (make_pre_fwd (idx ))
0 commit comments