-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathcheckpointing_train.py
More file actions
executable file
·88 lines (70 loc) · 3.39 KB
/
checkpointing_train.py
File metadata and controls
executable file
·88 lines (70 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python3
"""Training script with SIGTERM-aware checkpointing.
Saves periodic checkpoints during training. On SIGTERM (e.g., torc approaching
a time limit), saves an emergency checkpoint and exits cleanly. On restart,
resumes from the latest checkpoint automatically.
Expected environment variables (set by torc or the calling shell):
TORC_JOB_NAME - used to create a per-job checkpoint directory
MODEL_INDEX - index suffix for the output model file
"""
import json
import numpy as np
import os
import pickle
import signal
import sys
import time
# ── Configuration ──────────────────────────────────────────────────
# Paths are derived from the environment set by torc / the calling shell.
# A missing variable is a deliberate hard failure (KeyError) rather than
# a silent default — this script must never train under the wrong job.
job_name = os.environ["TORC_JOB_NAME"]
model_index = os.environ["MODEL_INDEX"]
ckpt_dir = os.path.join("/workspace/checkpoints", job_name)
model_out = os.path.join("/workspace/models", f"model_{model_index}.pt")
os.makedirs(ckpt_dir, exist_ok=True)
total_epochs = 100
# ── SIGTERM handling ───────────────────────────────────────────────
def handle_sigterm(_signum, _frame):
    """Flag the main loop to checkpoint and exit after the current epoch.

    The handler only flips a module-level flag; all actual checkpoint I/O
    happens in the training loop, outside async-signal context.
    """
    global terminated
    terminated = True
    print("SIGTERM received — will save checkpoint and exit after current epoch")

# Checked by the training loop between epochs.
terminated = False
signal.signal(signal.SIGTERM, handle_sigterm)
# ── Resume from checkpoint if available ────────────────────────────
# Checkpoints are named checkpoint_NNNN.npy with a zero-padded epoch, so
# plain lexicographic order matches epoch order and max() is the newest.
start_epoch = 0
weights = np.random.rand(128, 10) * 0.01

ckpt_files = [name for name in os.listdir(ckpt_dir) if name.startswith("checkpoint_")]
if ckpt_files:
    latest = os.path.join(ckpt_dir, max(ckpt_files))
    # np.save stored a dict, which round-trips as a 0-d object array;
    # .item() unwraps it. allow_pickle is safe here: we only load files
    # this script wrote into its own per-job directory.
    state = np.load(latest, allow_pickle=True).item()
    weights = state["weights"]
    start_epoch = state["epoch"] + 1
    print(f"Resuming from checkpoint at epoch {start_epoch}")
else:
    print("Starting fresh training")
# ── Load dataset ───────────────────────────────────────────────────
# NOTE(review): pickle.load runs arbitrary code on malicious input;
# acceptable only because this file is produced inside the workspace.
with open("/workspace/data/dataset.pkl", "rb") as dataset_file:
    dataset = pickle.load(dataset_file)
# ── Training loop ──────────────────────────────────────────────────
def _save_checkpoint(epoch, weights, loss):
    """Persist model state for *epoch* into the job's checkpoint dir.

    Zero-padded epoch in the filename keeps lexicographic order equal to
    epoch order, which the resume logic relies on.
    """
    ckpt_path = os.path.join(ckpt_dir, f"checkpoint_{epoch:04d}.npy")
    np.save(ckpt_path, {"weights": weights, "epoch": epoch, "loss": loss})

loss = float("inf")
for epoch in range(start_epoch, total_epochs):
    # Simulated training step: random "gradient" descent.
    grad = np.random.randn(*weights.shape) * 0.001
    weights -= grad
    loss = float(np.linalg.norm(grad))

    # Periodic checkpoint every 10 epochs.
    if (epoch + 1) % 10 == 0:
        _save_checkpoint(epoch, weights, loss)
        print(f"Epoch {epoch+1}/{total_epochs} loss={loss:.6f} [checkpoint saved]")
    else:
        print(f"Epoch {epoch+1}/{total_epochs} loss={loss:.6f}")

    # SIGTERM received during this epoch: persist state and exit cleanly
    # so the next run resumes from epoch + 1. Exit code 0 tells the
    # scheduler this was a graceful shutdown, not a failure.
    if terminated:
        _save_checkpoint(epoch, weights, loss)
        print(f"Emergency checkpoint saved at epoch {epoch+1}. Exiting.")
        sys.exit(0)

    time.sleep(1)  # Simulate compute time
# ── Save final model ───────────────────────────────────────────────
# Bug fix: np.save appends ".npy" to any path that lacks that suffix, so
# np.save(model_out, ...) silently wrote model_*.pt.npy while the log
# claimed model_*.pt. Saving through an open file object preserves the
# exact filename. Also make sure the output directory exists (only the
# checkpoint directory is created at startup).
os.makedirs(os.path.dirname(model_out), exist_ok=True)
with open(model_out, "wb") as model_file:
    np.save(model_file, {"weights": weights, "final_loss": loss})
print(f"Training complete. Model saved to {model_out}")