
Commit 2c1ec4c

aggregate val_loss (#1971)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5c4c09f commit 2c1ec4c

File tree

4 files changed: +44 −8 lines

litgpt/finetune/adapter.py

Lines changed: 11 additions & 2 deletions
@@ -347,8 +347,17 @@ def fit(
             val_loss = validate(fabric, model, val_dataloader, eval)
             generate_example(fabric, model, tokenizer, eval, data)
             t1 = time.perf_counter() - t0
-            fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
-            metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
+
+            val_loss_tensor = val_loss.detach().clone().to(fabric.device)
+            val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32)
+
+            fabric.all_reduce(val_loss_tensor, reduce_op="mean")
+            fabric.all_reduce(val_time_tensor, reduce_op="mean")
+
+            fabric.print(
+                f"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms"
+            )
+            metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)}
             fabric.log_dict(metrics, step=iter_num)
             fabric.barrier()
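For context, here is a minimal, self-contained sketch of the pattern this commit applies in all four finetuning scripts: each rank clones its local validation loss and averages it across processes with fabric.all_reduce(..., reduce_op="mean"), so every rank prints and logs the same aggregated value. The script below is illustrative only; the demo function, the two-rank CPU setup, and the fake per-rank losses are assumptions for the example and are not part of litgpt.

import torch
from lightning.fabric import Fabric


def demo(fabric):
    # Pretend each rank computed a different validation loss
    # (rank 0 -> 1.0, rank 1 -> 2.0).
    local_val_loss = torch.tensor(1.0 + fabric.global_rank, device=fabric.device)

    # Clone before reducing so the per-rank value stays untouched,
    # mirroring the val_loss.detach().clone() step in the commit.
    reduced = fabric.all_reduce(local_val_loss.detach().clone(), reduce_op="mean")

    # Every rank now holds the same averaged value (1.5 with two ranks).
    print(f"rank {fabric.global_rank}: local={local_val_loss.item():.2f}, reduced={reduced.item():.2f}")


if __name__ == "__main__":
    # Two CPU processes are enough to observe the aggregation; this launch
    # setup is illustrative and not how litgpt starts training.
    fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp_spawn")
    fabric.launch(demo)

Cloning before the reduce keeps the tensor returned by validate() unmodified, and on a single device the all_reduce is effectively a no-op, so non-distributed runs should behave as before.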

litgpt/finetune/adapter_v2.py

Lines changed: 11 additions & 2 deletions
@@ -348,8 +348,17 @@ def fit(
             val_loss = validate(fabric, model, val_dataloader, eval)
             generate_example(fabric, model, tokenizer, eval, data)
             t1 = time.perf_counter() - t0
-            fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
-            metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
+
+            val_loss_tensor = val_loss.detach().clone().to(fabric.device)
+            val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32)
+
+            fabric.all_reduce(val_loss_tensor, reduce_op="mean")
+            fabric.all_reduce(val_time_tensor, reduce_op="mean")
+
+            fabric.print(
+                f"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms"
+            )
+            metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)}
             fabric.log_dict(metrics, step=iter_num)
             fabric.barrier()

litgpt/finetune/full.py

Lines changed: 11 additions & 2 deletions
@@ -320,8 +320,17 @@ def fit(
             val_loss = validate(fabric, model, val_dataloader, eval)
             generate_example(fabric, model, tokenizer, eval, data)
             t1 = time.perf_counter() - t0
-            fabric.print(f"iter {state['iter_num']}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
-            metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
+
+            val_loss_tensor = val_loss.detach().clone().to(fabric.device)
+            val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32)
+
+            fabric.all_reduce(val_loss_tensor, reduce_op="mean")
+            fabric.all_reduce(val_time_tensor, reduce_op="mean")
+
+            fabric.print(
+                f"iter {state['iter_num']}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms"
+            )
+            metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)}
             fabric.log_dict(metrics, step=state["iter_num"])
             fabric.barrier()
         if train.save_interval is not None and not is_accumulating and state["step_count"] % train.save_interval == 0:

litgpt/finetune/lora.py

Lines changed: 11 additions & 2 deletions
@@ -379,8 +379,17 @@ def fit(
             val_loss = validate(fabric, model, val_dataloader, eval)
             generate_example(fabric, model, tokenizer, eval, data)
             t1 = time.perf_counter() - t0
-            fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
-            metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
+
+            val_loss_tensor = val_loss.detach().clone().to(fabric.device)
+            val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32)
+
+            fabric.all_reduce(val_loss_tensor, reduce_op="mean")
+            fabric.all_reduce(val_time_tensor, reduce_op="mean")
+
+            fabric.print(
+                f"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms"
+            )
+            metrics = {"val_loss": val_loss_tensor, "val_ppl": math.exp(val_loss_tensor)}
             fabric.log_dict(metrics, step=iter_num)
             fabric.barrier()
