Skip to content

Commit eedf1ca

Browse files
committed
chore: lint
1 parent 11b17a8 commit eedf1ca

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

src/axolotl/core/trainers/mixins/layer_offloading.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
     queue = [model]
     while queue:
         m = queue.pop(0)
-        for name, child in m.named_children():
+        for _name, child in m.named_children():
             if isinstance(child, nn.ModuleList) and len(child) > 0:
                 first_type = type(child[0]).__name__
                 if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
@@ -70,7 +70,9 @@ def __init__(
         # Find decoder layers
         self.layers, layer_types = _find_decoder_layers(model)
         if self.layers is None:
-            LOG.warning("LayerOffloadManager: no decoder layers found, offloading disabled")
+            LOG.warning(
+                "LayerOffloadManager: no decoder layers found, offloading disabled"
+            )
             self.enabled = False
             return

@@ -103,7 +105,9 @@ def __init__(

         # CPU storage: pinned tensors for each layer's frozen params
         # Populated on first offload
-        self._cpu_data: list[dict[str, torch.Tensor]] = [{} for _ in range(self.n_layers)]
+        self._cpu_data: list[dict[str, torch.Tensor]] = [
+            {} for _ in range(self.n_layers)
+        ]

         # Offload all layers upfront
         self._offload_all()
@@ -146,9 +150,13 @@ def _load_layer(self, idx: int, stream=None):
         """Move frozen params of layer idx back to GPU."""
         if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
             return
-        ctx = torch.cuda.stream(stream) if stream is not None else contextlib.nullcontext()
+        ctx = (
+            torch.cuda.stream(stream)
+            if stream is not None
+            else contextlib.nullcontext()
+        )
         with ctx:
-            for name, param in self._frozen_params[idx]:
+            for _name, param in self._frozen_params[idx]:
                 if param.device.type == "cuda":
                     continue
                 gpu_data = param.data.to(self._device, non_blocking=True)
@@ -183,6 +191,7 @@ def hook(module, args):
                 # Prefetch next layer(s)
                 for offset in range(1, self.num_prefetch + 1):
                     self._prefetch_layer(i + offset)
+
             return hook

         def make_post_fwd(i):
@@ -193,6 +202,7 @@ def hook(module, args, output):
                 # Offload last layer after forward
                 if i == self.n_layers - 1:
                     self._offload_layer(i)
+
             return hook

         def make_pre_bwd(i):
@@ -204,6 +214,7 @@ def hook(module, grad_output):
                 # Prefetch previous layer(s)
                 for offset in range(1, self.num_prefetch + 1):
                     self._prefetch_layer(i - offset)
+
             return hook

         def make_post_bwd(i):
@@ -214,6 +225,7 @@ def hook(module, grad_input, grad_output):
                 # Offload first layer after backward
                 if i == 0:
                     self._offload_layer(i)
+
             return hook

         h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))

src/axolotl/core/training_args_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,9 @@ class AxolotlTrainingMixins:

     layer_offloading: bool | None = field(
         default=None,
-        metadata={"help": "Offload model layer parameters to CPU during forward, prefetch back during backward."},
+        metadata={
+            "help": "Offload model layer parameters to CPU during forward, prefetch back during backward."
+        },
     )

     # multi-modal section

0 commit comments

Comments (0)