
Commit 7baccd4

altria-zewei-wang and alal authored
Add resume for adapter_v2, enable continued finetuning for adapter (#1354)
Co-authored-by: alal <[email protected]>
1 parent db6b08d commit 7baccd4

File tree: 2 files changed, +43 -4 lines changed

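The new resume argument is added to setup() in litgpt/finetune/adapter_v2.py (see the diff below), so a continued finetuning run can be requested by passing resume=True. A hypothetical invocation, assuming setup()'s existing checkpoint_dir and out_dir arguments and placeholder paths:

from pathlib import Path
from litgpt.finetune.adapter_v2 import setup

# Placeholder paths; resume=True makes the run look for the newest
# step-*/*.pth.adapter_v2 under out_dir and continue from it.
setup(
    checkpoint_dir=Path("checkpoints/my-base-model"),
    out_dir=Path("out/finetune/adapter-v2"),
    resume=True,
)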

litgpt/finetune/adapter_v2.py

Lines changed: 30 additions & 4 deletions
@@ -37,6 +37,7 @@
     instantiate_bnb_optimizer,
     instantiate_torch_optimizer,
     load_checkpoint,
+    load_checkpoint_update,
     num_parameters,
     parse_devices,
     save_hyperparameters,
@@ -51,6 +52,7 @@ def setup(
     quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
     devices: Union[int, str] = 1,
     num_nodes: int = 1,
+    resume: Optional[bool] = False,
     data: Optional[DataModule] = None,
     train: TrainArgs = TrainArgs(
         save_interval=1000,
@@ -137,7 +139,7 @@ def setup(
     if torch.cuda.is_available() and devices > 1:
         check_nvlink_connectivity(fabric)
 
-    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)
+    fabric.launch(main, devices, seed, config, data, resume, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)
 
 
 def main(
@@ -146,6 +148,7 @@ def main(
     seed: int,
     config: Config,
     data: DataModule,
+    resume: bool,
     checkpoint_dir: Path,
     out_dir: Path,
     train: TrainArgs,
@@ -191,9 +194,22 @@ def main(
 
     optimizer = fabric.setup_optimizers(optimizer)
     scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
+    if resume:
+        # Finding last trace of adapter training
+        try:
+            resume = max(out_dir.rglob("step-*/*.pth.adapter_v2"), key=(lambda p: int(p.parent.name.split("-")[1])))
+            fabric.print(f"Resuming training from {resume}")
+            load_checkpoint_update(fabric, resume, model, checkpoint_path, strict=False)
+            resume = True
+        except ValueError:
+            fabric.print("No previous adapter found. Finetune from start.")
+            resume = False
+            load_checkpoint(fabric, model, checkpoint_path, strict=False)
+    else:
+        # strict=False because missing keys due to Adapter weights not contained in state dict
+        load_checkpoint(fabric, model, checkpoint_path, strict=False)
 
-    # strict=False because missing keys due to Adapter weights not contained in state dict
-    load_checkpoint(fabric, model, checkpoint_path, strict=False)
+    mark_only_adapter_v2_as_trainable(model)
 
     train_time = time.perf_counter()
     token_counts = fit(
@@ -204,6 +220,7 @@ def main(
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         devices=devices,
+        resume=resume,
         num_nodes=num_nodes,
         checkpoint_dir=checkpoint_dir,
         out_dir=out_dir,
@@ -241,6 +258,7 @@ def fit(
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     devices: int,
+    resume: bool,
     checkpoint_dir: Path,
     out_dir: Path,
     train: TrainArgs,
@@ -283,7 +301,15 @@ def fit(
         "raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long),
     }
 
-    while step_count < max_steps:
+    if not resume:
+        try:
+            iter_match = max(out_dir.rglob("step-*/*.pth.adapter_v2"), key=lambda p: int(p.parent.name.split("-")[1]))
+            step_count = int(iter_match.parent.name.split("-")[1]) if iter_match else 0
+        except ValueError:
+            step_count = 0
+
+    fabric.print(f"Starting at step count {step_count}")
+    while step_count < max_steps and train_iterator.epoch < train.epochs:
         iter_num += 1
         iter_t0 = time.perf_counter()
         batch = next(train_iterator)
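The resume logic above locates the most recent adapter checkpoint by globbing out_dir for step-*/*.pth.adapter_v2 files and taking the one whose parent directory carries the highest step number. A minimal, standalone sketch of that discovery step (the helper name find_latest_adapter is illustrative, not part of litgpt):

from pathlib import Path
from typing import Optional


def find_latest_adapter(out_dir: Path) -> Optional[Path]:
    # Each saved adapter lives under a directory named "step-<N>"; the newest
    # checkpoint is the one whose parent directory has the largest <N>.
    try:
        return max(
            out_dir.rglob("step-*/*.pth.adapter_v2"),
            key=lambda p: int(p.parent.name.split("-")[1]),
        )
    except ValueError:
        # max() raises ValueError on an empty iterable, i.e. nothing to resume from.
        return None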

litgpt/utils.py

Lines changed: 13 additions & 0 deletions
@@ -391,6 +391,19 @@ def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, s
     model.load_state_dict(state_dict, strict=strict)
 
 
+def load_checkpoint_update(
+    fabric: L.Fabric, adapter_path: Path, model: nn.Module, checkpoint_path: Path, strict: bool = True
+) -> None:
+    if isinstance(fabric.strategy, FSDPStrategy):
+        fabric.load_raw(checkpoint_path, model, strict=strict)
+    else:
+        state_dict = lazy_load(checkpoint_path)
+        state_dict = state_dict.get("model", state_dict)
+        adapter_cp = lazy_load(adapter_path)
+        state_dict.update(adapter_cp)
+        model.load_state_dict(state_dict, strict=strict)
+
+
 def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
     flops_per_token = 2 * n_params  # each parameter is used for a MAC (2 FLOPS) per network operation
     # this assumes that all samples have a fixed length equal to the block size
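load_checkpoint_update overlays the saved adapter tensors on top of the pretrained base state dict before calling load_state_dict, so a resumed run starts from the base weights plus the latest adapter weights. A rough, self-contained equivalent of the non-FSDP branch, using plain torch.load in place of litgpt's lazy_load; the function name and strict default here are illustrative:

import torch
from torch import nn


def merge_base_and_adapter(model: nn.Module, checkpoint_path: str, adapter_path: str, strict: bool = False) -> None:
    # Base checkpoint: the weights may be wrapped in a {"model": ...} dict.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    state_dict = state_dict.get("model", state_dict)
    # Adapter checkpoint: contains only the adapter_v2 tensors; update() overlays them
    # on the base weights (matching keys are replaced, new keys are added).
    adapter_sd = torch.load(adapter_path, map_location="cpu")
    state_dict.update(adapter_sd)
    # strict=False mirrors the call site in adapter_v2.py, tolerating key mismatches.
    model.load_state_dict(state_dict, strict=strict)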
