Commit bd8cec9

Run evaluation at end of training (#1332)
1 parent: 5d96660

8 files changed: +83 / -42 lines

litgpt/finetune/adapter.py

Lines changed: 19 additions & 10 deletions

@@ -178,6 +178,12 @@ def main(
     if fabric.device.type == "cuda":
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

+    # Final evaluation
+    val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
+    metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
+    fabric.log_dict(metrics)
+    fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
+
     # Save the final Adapter checkpoint at the end of training
     save_path = out_dir / "final" / "lit_model.pth.adapter"
     save_path.parent.mkdir(parents=True, exist_ok=True)

@@ -211,7 +217,7 @@ def fit(
         f" {model.max_seq_length} and context length is {model.config.block_size}"
     )

-    validate(fabric, model, val_dataloader, tokenizer, dataclasses.replace(eval, max_iters=2), data)  # sanity check
+    validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2))  # sanity check

     train_iterator = CycleIterator(train_dataloader)
     throughput = ThroughputMonitor(fabric, window_size=50)

@@ -278,7 +284,8 @@ def fit(

         if not is_accumulating and step_count % eval.interval == 0:
             t0 = time.perf_counter()
-            val_loss = validate(fabric, model, val_dataloader, tokenizer, eval, data)
+            val_loss = validate(fabric, model, val_dataloader, eval)
+            generate_example(fabric, model, tokenizer, eval, data)
             t1 = time.perf_counter() - t0
             fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
             metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}

@@ -295,11 +302,8 @@ def fit(
                 save_prompt_style(data.prompt_style, checkpoint_file.parent)


-# the adapter "kv cache" cannot be initialized under `inference_mode`
 @torch.no_grad()
-def validate(
-    fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule
-) -> torch.Tensor:
+def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs) -> torch.Tensor:
     fabric.print("Validating ...")
     model.eval()
     losses = torch.zeros(min(len(val_dataloader), eval.max_iters))

@@ -311,25 +315,30 @@ def validate(
         losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)

     val_loss = losses.mean()
+    model.train()
+    return val_loss

-    # produce an example:
+
+# the adapter "kv cache" cannot be initialized under `inference_mode`
+@torch.no_grad()
+def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
     instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
+    model.eval()
+
     with fabric.init_tensor():
         # do not set `max_seq_length=max_returned_token` because memory is not a concern here
         model.set_kv_cache(batch_size=1)
     output = generate(
         model, encoded, max_returned_tokens=len(encoded) + eval.max_new_tokens, temperature=0.8, eos_id=tokenizer.eos_id
     )
     model.clear_kv_cache()
+    model.train()
     output = tokenizer.decode(output)
     fabric.print(output)

-    model.train()
-    return val_loss
-

 def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
     # linear warmup followed by cosine annealing
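
The block added in the first hunk reuses the existing validate() helper for one full pass over the validation set by overriding EvalArgs.max_iters via dataclasses.replace. Below is a minimal sketch of that pattern pulled out of main(); the wrapper name run_final_eval is made up for illustration, the import assumes a litgpt checkout at this commit, and validate is the helper refactored further down in this file.

import dataclasses
import math

from litgpt.finetune.adapter import validate  # post-commit signature: (fabric, model, val_dataloader, eval)


def run_final_eval(fabric, model, val_dataloader, eval_args):
    # Lift max_iters to the dataloader length so every validation batch is scored,
    # instead of stopping at the smaller cap used for the periodic in-training evals.
    full_eval = dataclasses.replace(eval_args, max_iters=len(val_dataloader))
    val_loss = validate(fabric, model, val_dataloader, full_eval)
    # Perplexity is exp(mean cross-entropy), so it is derived directly from val_loss.
    metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
    fabric.log_dict(metrics)
    fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
    return metrics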

litgpt/finetune/adapter_v2.py

Lines changed: 19 additions & 10 deletions

@@ -178,6 +178,12 @@ def main(
     if fabric.device.type == "cuda":
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

+    # Final evaluation
+    val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
+    metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
+    fabric.log_dict(metrics)
+    fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
+
     # Save the final Adapter checkpoint at the end of training
     save_path = out_dir / "final" / "lit_model.pth.adapter_v2"
     save_path.parent.mkdir(parents=True, exist_ok=True)

@@ -211,7 +217,7 @@ def fit(
         f" {model.max_seq_length} and context length is {model.config.block_size}"
     )

-    validate(fabric, model, val_dataloader, tokenizer, dataclasses.replace(eval, max_iters=2), data)  # sanity check
+    validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2))  # sanity check

     train_iterator = CycleIterator(train_dataloader)
     throughput = ThroughputMonitor(fabric, window_size=50)

@@ -278,7 +284,8 @@ def fit(

         if not is_accumulating and step_count % eval.interval == 0:
             t0 = time.perf_counter()
-            val_loss = validate(fabric, model, val_dataloader, tokenizer, eval, data)
+            val_loss = validate(fabric, model, val_dataloader, eval)
+            generate_example(fabric, model, tokenizer, eval, data)
             t1 = time.perf_counter() - t0
             fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
             metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}

@@ -295,11 +302,8 @@ def fit(
                 save_prompt_style(data.prompt_style, checkpoint_file.parent)


-# the adapter "kv cache" cannot be initialized under `inference_mode`
 @torch.no_grad()
-def validate(
-    fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule
-) -> torch.Tensor:
+def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs) -> torch.Tensor:
     fabric.print("Validating ...")
     model.eval()
     losses = torch.zeros(min(len(val_dataloader), eval.max_iters))

@@ -311,25 +315,30 @@ def validate(
         losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)

     val_loss = losses.mean()
+    model.train()
+    return val_loss

-    # produce an example:
+
+# the adapter "kv cache" cannot be initialized under `inference_mode`
+@torch.no_grad()
+def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
     instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
+    model.eval()
+
     with fabric.init_tensor():
         # do not set `max_seq_length=max_returned_token` because memory is not a concern here
         model.set_kv_cache(batch_size=1)
     output = generate(
         model, encoded, max_returned_tokens=len(encoded) + eval.max_new_tokens, temperature=0.8, eos_id=tokenizer.eos_id
     )
     model.clear_kv_cache()
+    model.train()
     output = tokenizer.decode(output)
     fabric.print(output)

-    model.train()
-    return val_loss
-

 def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
     # linear warmup followed by cosine annealing
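
A behavioural detail in these hunks: validate() and generate_example() now each flip the model to eval mode and restore train mode themselves, rather than the old combined function doing it once at the very end. A tiny self-contained illustration in plain PyTorch (not litgpt code) of why that bracketing matters for modules such as dropout:

import torch
import torch.nn as nn


@torch.no_grad()
def quick_validate(model: nn.Module, batches) -> torch.Tensor:
    model.eval()  # dropout/norm layers switch to their deterministic behaviour
    losses = torch.stack([model(x).pow(2).mean() for x in batches])
    model.train()  # hand the model back to the training loop in train mode
    return losses.mean()


model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
loss = quick_validate(model, [torch.randn(2, 4) for _ in range(3)])
assert model.training  # mode was restored, as the refactored helpers now guarantee
print(f"toy val loss: {loss.item():.3f}")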

litgpt/finetune/full.py

Lines changed: 18 additions & 9 deletions

@@ -150,6 +150,12 @@ def main(
     if fabric.device.type == "cuda":
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

+    # Final evaluation
+    val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
+    metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
+    fabric.log_dict(metrics, step=state["iter_num"])
+    fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
+
     # Save the final checkpoint at the end of training
     save_path = out_dir / "final" / "lit_model.pth"
     save_path.parent.mkdir(parents=True, exist_ok=True)

@@ -185,7 +191,7 @@ def fit(
         f" {model.max_seq_length} and context length is {model.config.block_size}"
     )

-    validate(fabric, model, val_dataloader, tokenizer, dataclasses.replace(eval, max_iters=2), data)  # sanity check
+    validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2))  # sanity check
     initial_iter = state["iter_num"]
     max_steps = train.max_steps or float("inf")
     train_iterator = CycleIterator(train_dataloader)

@@ -258,7 +264,8 @@ def fit(

         if not is_accumulating and state["step_count"] % eval.interval == 0:
             t0 = time.perf_counter()
-            val_loss = validate(fabric, model, val_dataloader, tokenizer, eval, data)
+            val_loss = validate(fabric, model, val_dataloader, eval)
+            generate_example(fabric, model, tokenizer, eval, data)
             t1 = time.perf_counter() - t0
             fabric.print(f"iter {state['iter_num']}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
             metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}

@@ -277,9 +284,7 @@ def fit(

 # FSDP has issues with `inference_mode`
 @torch.no_grad()
-def validate(
-    fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule
-) -> torch.Tensor:
+def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs) -> torch.Tensor:
     fabric.print("Validating ...")
     model.eval()
     losses = torch.zeros(min(len(val_dataloader), eval.max_iters))

@@ -291,25 +296,29 @@ def validate(
         losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)

     val_loss = losses.mean()
+    model.train()
+    return val_loss

-    # produce an example:
+
+@torch.no_grad()
+def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
     instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
+    model.eval()
+
     with fabric.init_tensor():
         # do not set `max_seq_length=max_returned_token` because memory is not a concern here
         model.set_kv_cache(batch_size=1)
     output = generate(
         model, encoded, max_returned_tokens=len(encoded) + eval.max_new_tokens, temperature=0.8, eos_id=tokenizer.eos_id
     )
     model.clear_kv_cache()
+    model.train()
     output = tokenizer.decode(output)
     fabric.print(output)

-    model.train()
-    return val_loss
-

 def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
     # linear warmup followed by cosine annealing
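
Because the hunks interleave added and unchanged lines, here is the new generate_example() assembled from the diff above, as it reads in full.py after this commit. The adapter and adapter_v2 variants add a comment about the adapter KV cache above the decorator, the LoRA version is identical to this one, and module-level imports are omitted:

@torch.no_grad()
def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
    fabric.print(instruction)
    prompt = data.prompt_style.apply(instruction)
    encoded = tokenizer.encode(prompt, device=fabric.device)
    model.eval()

    with fabric.init_tensor():
        # do not set `max_seq_length=max_returned_token` because memory is not a concern here
        model.set_kv_cache(batch_size=1)
    output = generate(
        model, encoded, max_returned_tokens=len(encoded) + eval.max_new_tokens, temperature=0.8, eos_id=tokenizer.eos_id
    )
    model.clear_kv_cache()
    model.train()
    output = tokenizer.decode(output)
    fabric.print(output)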

litgpt/finetune/lora.py

Lines changed: 19 additions & 9 deletions

@@ -208,6 +208,12 @@ def main(
     if fabric.device.type == "cuda":
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

+    # Final evaluation
+    val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))
+    metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
+    fabric.log_dict(metrics)
+    fabric.print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
+
     # Save the final LoRA checkpoint at the end of training
     save_path = out_dir / "final" / "lit_model.pth.lora"
     save_path.parent.mkdir(parents=True, exist_ok=True)

@@ -242,7 +248,7 @@ def fit(
         f" {model.max_seq_length} and context length is {model.config.block_size}"
     )

-    validate(fabric, model, val_dataloader, tokenizer, dataclasses.replace(eval, max_iters=2), data)  # sanity check
+    validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2))  # sanity check

     train_iterator = CycleIterator(train_dataloader)
     throughput = ThroughputMonitor(fabric, window_size=50)

@@ -309,7 +315,8 @@ def fit(

         if not is_accumulating and step_count % eval.interval == 0:
             t0 = time.perf_counter()
-            val_loss = validate(fabric, model, val_dataloader, tokenizer, eval, data)
+            val_loss = validate(fabric, model, val_dataloader, eval)
+            generate_example(fabric, model, tokenizer, eval, data)
             t1 = time.perf_counter() - t0
             fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
             metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}

@@ -328,9 +335,7 @@ def fit(

 # FSDP has issues with `inference_mode`
 @torch.no_grad()
-def validate(
-    fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule
-) -> torch.Tensor:
+def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs) -> torch.Tensor:
     fabric.print("Validating ...")
     model.eval()
     losses = torch.zeros(min(len(val_dataloader), eval.max_iters))

@@ -343,24 +348,29 @@ def validate(

     val_loss = losses.mean()

-    # produce an example:
+    model.train()
+    return val_loss
+
+
+@torch.no_grad()
+def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
     instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
+    model.eval()
+
     with fabric.init_tensor():
         # do not set `max_seq_length=max_returned_token` because memory is not a concern here
         model.set_kv_cache(batch_size=1)
     output = generate(
         model, encoded, max_returned_tokens=len(encoded) + eval.max_new_tokens, temperature=0.8, eos_id=tokenizer.eos_id
     )
     model.clear_kv_cache()
+    model.train()
     output = tokenizer.decode(output)
     fabric.print(output)

-    model.train()
-    return val_loss
-

 def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
     # linear warmup followed by cosine annealing
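
All four scripts print the final metrics with the same fixed prefix, which the updated tests below grep for. A quick standalone check of the format and of the loss-to-perplexity relationship (the 1.234 value is a placeholder, not a measured result):

import math

import torch

val_loss = torch.tensor(1.234)  # placeholder value for illustration only
print(f"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}")
# prints: Final evaluation | val loss: 1.234 | val ppl: 3.435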

tests/test_adapter.py

Lines changed: 2 additions & 1 deletion

@@ -98,7 +98,8 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)

     logs = stdout.getvalue()
     assert logs.count("(step)") == 6
-    assert logs.count("val loss") == 3
+    assert logs.count("val loss") == 4  # 3 validations + 1 final validation
+    assert logs.count("Final evaluation") == 1
     assert "of trainable parameters: 168" in logs

tests/test_adapter_v2.py

Lines changed: 2 additions & 1 deletion

@@ -115,7 +115,8 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_pa

     logs = stdout.getvalue()
     assert logs.count("(step)") == 6
-    assert logs.count("val loss") == 3
+    assert logs.count("val loss") == 4  # 3 validations + 1 final validation
+    assert logs.count("Final evaluation") == 1
     assert "of trainable parameters: 552" in logs

tests/test_full.py

Lines changed: 2 additions & 1 deletion

@@ -55,7 +55,8 @@ def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):

     logs = stdout.getvalue()
     assert logs.count("(step)") == 6
-    assert logs.count("val loss") == 3
+    assert logs.count("val loss") == 4  # 3 validations + 1 final validation
+    assert logs.count("Final evaluation") == 1
     assert "of trainable parameters: 1,888" in logs

     # Resume training and do 2 steps more

tests/test_lora.py

Lines changed: 2 additions & 1 deletion

@@ -221,7 +221,8 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):

     logs = stdout.getvalue()
     assert logs.count("(step)") == 6
-    assert logs.count("val loss") == 3
+    assert logs.count("val loss") == 4  # 3 validations + 1 final validation
+    assert logs.count("Final evaluation") == 1
     assert "of trainable parameters: 512" in logs
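
The count-based assertions rely on each log line having a distinctive substring. A minimal self-contained sketch of the same counting technique; this is not the project's actual test harness, which captures the scripts' stdout through its own fixtures:

import io
from contextlib import redirect_stdout

stdout = io.StringIO()
with redirect_stdout(stdout):
    # Stand-ins for the lines the finetuning scripts print.
    print("iter 100: val loss 1.2345, val time: 3.00 ms")
    print("Final evaluation | val loss: 1.230 | val ppl: 3.421")

logs = stdout.getvalue()
assert logs.count("val loss") == 2  # one periodic eval plus one final eval in this toy log
assert logs.count("Final evaluation") == 1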
