```diff
     instantiate_bnb_optimizer,
     instantiate_torch_optimizer,
     load_checkpoint,
+    load_checkpoint_update,
     num_parameters,
     parse_devices,
     save_hyperparameters,
```
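The only change in this hunk is the new `load_checkpoint_update` import; the helper itself is defined elsewhere in the PR and its body is not part of this diff. Below is a minimal sketch of what such a helper could look like, assuming it restores the base weights first and then overlays the saved adapter state — the signature is inferred from the call site further down, and everything else here is an assumption, not the PR's actual implementation:

```python
from pathlib import Path

import lightning as L
import torch


def load_checkpoint_update(
    fabric: L.Fabric, adapter_path: Path, model: torch.nn.Module,
    checkpoint_path: Path, strict: bool = False,
) -> None:
    """Hypothetical sketch: restore base weights, then overlay adapter weights."""
    for path, load_strict in ((checkpoint_path, strict), (adapter_path, False)):
        state = torch.load(path, map_location="cpu")
        state = state.get("model", state)  # unwrap if the file stores {"model": ...}
        # strict=False tolerates the keys each file is missing relative to the model
        model.load_state_dict(state, strict=load_strict)
    fabric.print(f"Loaded {checkpoint_path} and overlaid {adapter_path}")
```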
```diff
@@ -51,6 +52,7 @@ def setup(
     quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
     devices: Union[int, str] = 1,
     num_nodes: int = 1,
+    resume: bool = False,
     data: Optional[DataModule] = None,
     train: TrainArgs = TrainArgs(
         save_interval=1000,
```
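`setup` now accepts a `resume` flag, typed `bool` to match the `main` signature below. A hypothetical invocation, assuming the surrounding signature matches upstream litgpt where `checkpoint_dir` and `out_dir` are the leading parameters (both paths are placeholders; through the litgpt CLI this would surface as a `--resume` option):

```python
from pathlib import Path

setup(
    checkpoint_dir=Path("checkpoints/meta-llama/Llama-2-7b-hf"),  # placeholder
    out_dir=Path("out/finetune/adapter-v2"),                      # placeholder
    resume=True,  # continue from the newest step-*/ adapter checkpoint in out_dir
)
```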
```diff
@@ -137,7 +139,7 @@ def setup(
     if torch.cuda.is_available() and devices > 1:
         check_nvlink_connectivity(fabric)
 
-    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)
+    fabric.launch(main, devices, seed, config, data, resume, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)
 
 
 def main(
```
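Since `fabric.launch` forwards its extra arguments to `main` positionally (with the `Fabric` instance itself prepended as the first argument), `resume` must occupy the same slot here as in `main`'s parameter list in the next hunk.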
```diff
@@ -146,6 +148,7 @@ def main(
     seed: int,
     config: Config,
     data: DataModule,
+    resume: bool,
     checkpoint_dir: Path,
     out_dir: Path,
     train: TrainArgs,
```
@@ -191,9 +194,22 @@ def main(
191
194
192
195
optimizer = fabric .setup_optimizers (optimizer )
193
196
scheduler = get_lr_scheduler (optimizer , warmup_steps = train .lr_warmup_steps , max_steps = lr_max_steps )
197
+ if resume :
198
+ # Finding last trace of adapter training
199
+ try :
200
+ resume = max (out_dir .rglob ("step-*/*.pth.adapter_v2" ), key = (lambda p : int (p .parent .name .split ("-" )[1 ])))
201
+ fabric .print (f"Resuming training from { resume } " )
202
+ load_checkpoint_update (fabric , resume , model , checkpoint_path , strict = False )
203
+ resume = True
204
+ except ValueError :
205
+ fabric .print ("No previous adapter found. Finetune from start." )
206
+ resume = False
207
+ load_checkpoint (fabric , model , checkpoint_path , strict = False )
208
+ else :
209
+ # strict=False because missing keys due to Adapter weights not contained in state dict
210
+ load_checkpoint (fabric , model , checkpoint_path , strict = False )
194
211
195
- # strict=False because missing keys due to Adapter weights not contained in state dict
196
- load_checkpoint (fabric , model , checkpoint_path , strict = False )
212
+ mark_only_adapter_v2_as_trainable (model )
197
213
198
214
train_time = time .perf_counter ()
199
215
token_counts = fit (
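Checkpoint discovery rests on a small pathlib idiom: `max()` over an empty `rglob()` iterator raises `ValueError`, which doubles as the "nothing to resume from" signal handled by the `except` branch. Note also that `resume` briefly holds the checkpoint `Path` before being normalized back to a `bool` for the call into `fit`. A standalone illustration of the discovery logic, with placeholder paths:

```python
from pathlib import Path

out_dir = Path("out/finetune/adapter-v2")  # placeholder output directory
try:
    # "step-003000" -> 3000; max() raises ValueError if the glob matches nothing
    latest = max(
        out_dir.rglob("step-*/*.pth.adapter_v2"),
        key=lambda p: int(p.parent.name.split("-")[1]),
    )
    print(f"Would resume from {latest}")
except ValueError:
    print("No step-*/ adapter checkpoints found; starting from the base weights.")
```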
```diff
@@ -204,6 +220,7 @@ def main(
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         devices=devices,
+        resume=resume,
         num_nodes=num_nodes,
         checkpoint_dir=checkpoint_dir,
         out_dir=out_dir,
```
```diff
@@ -241,6 +258,7 @@ def fit(
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     devices: int,
+    resume: bool,
     checkpoint_dir: Path,
     out_dir: Path,
     train: TrainArgs,
```
```diff
@@ -283,7 +301,16 @@ def fit(
         "raw_tokens_plus_prompt_template_and_padding": torch.tensor(0, device=fabric.device, dtype=torch.long),
     }
 
-    while step_count < max_steps:
+    if resume:
+        # Recover the step count from the most recent adapter checkpoint on disk
+        try:
+            latest = max(out_dir.rglob("step-*/*.pth.adapter_v2"), key=lambda p: int(p.parent.name.split("-")[1]))
+            step_count = int(latest.parent.name.split("-")[1])
+        except ValueError:
+            step_count = 0
+
+    fabric.print(f"Starting at step count {step_count}")
+    while step_count < max_steps and train_iterator.epoch < train.epochs:
         iter_num += 1
         iter_t0 = time.perf_counter()
         batch = next(train_iterator)
```
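The loop condition now also respects `train.epochs` via `train_iterator.epoch`, so a resumed run that has already exhausted its epoch budget exits immediately instead of looping on `step_count`. The step count itself is recovered by parsing the checkpoint's parent directory name; the zero-padded `step-%06d` naming is assumed from the save path used elsewhere in this script:

```python
# How a step count is recovered from a checkpoint directory name (assumed
# "step-%06d" layout, e.g. out/finetune/adapter-v2/step-000400/...):
name = "step-000400"
step_count = int(name.split("-")[1])
assert step_count == 400  # int() tolerates the zero padding
```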