
Commit fd3100a

Merge pull request #10 from jeremymanning/main
Add --resume flag for training continuation
2 parents: f945450 + 4d73595

7 files changed: +245 −35 lines


README.md

Lines changed: 19 additions & 1 deletion
@@ -131,6 +131,9 @@ python generate_figures.py --figure 1a
 # Train models from scratch
 python generate_figures.py --train

+# Resume training from existing checkpoints
+python generate_figures.py --train --resume
+
 # List available figures
 python generate_figures.py --list
 ```
@@ -175,6 +178,9 @@ fig = generate_all_losses_figure(
 # Using the CLI (recommended - handles all steps automatically)
 ./run_llm_stylometry.sh --train

+# Resume training from existing checkpoints
+./run_llm_stylometry.sh --train --resume
+
 # Limit GPU usage if needed
 ./run_llm_stylometry.sh --train --max-gpus 4
 ```
@@ -184,6 +190,12 @@ This command will:
 2. Train all 80 models (8 authors × 10 seeds)
 3. Consolidate results into `data/model_results.pkl`

+**Resume Training**: The `--resume` flag allows you to continue training from existing checkpoints:
+- Models that have already met training criteria are automatically skipped
+- Partially trained models with saved weights resume from their last checkpoint
+- Models without weights are trained from scratch (even if loss logs exist)
+- Random states are restored from checkpoints to ensure consistent training continuation
+
 The training pipeline automatically handles data preparation, model training across available GPUs, and result consolidation. Individual model checkpoints and loss logs are saved in the `models/` directory.

 ### Remote Training on GPU Server
@@ -226,15 +238,21 @@ Once Git credentials are configured on your server, run `remote_train.sh` **from
 # From your local machine, start training on the remote GPU server
 ./remote_train.sh

+# Resume training from existing checkpoints
+./remote_train.sh --resume # or -r
+
 # Kill existing training sessions and optionally start new one
 ./remote_train.sh --kill # or -k

+# Kill and resume (restart interrupted training)
+./remote_train.sh --kill --resume
+
 # You'll be prompted for:
 # - Server address (hostname or IP)
 # - Username
 ```

-**What this script does:** The `remote_train.sh` script connects to your GPU server via SSH and executes `run_llm_stylometry.sh --train -y` in a `screen` session. This allows you to disconnect your local machine while the GPU server continues training.
+**What this script does:** The `remote_train.sh` script connects to your GPU server via SSH and executes `run_llm_stylometry.sh --train -y` (or `--train --resume -y` if resuming) in a `screen` session. This allows you to disconnect your local machine while the GPU server continues training.

 The script will:
 1. SSH into your GPU server
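For readers who want to script the workflow documented above, here is a minimal sketch that drives the same CLI from Python instead of the shell. The `--train` and `--resume` flags and the `generate_figures.py` entry point are taken from this diff; the wrapper function itself is hypothetical and assumes it is run from the directory containing `generate_figures.py`, as in the README examples.

```python
import subprocess
import sys

def run_training(resume: bool = False) -> int:
    """Invoke the documented CLI; pass resume=True to continue from checkpoints."""
    cmd = [sys.executable, "generate_figures.py", "--train"]
    if resume:
        cmd.append("--resume")
    # Stream output directly so training progress stays visible.
    return subprocess.run(cmd, check=False).returncode

if __name__ == "__main__":
    # Do a fresh run first; rerun with resume=True after an interruption.
    raise SystemExit(run_training(resume=True))
```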

code/generate_figures.py

Lines changed: 41 additions & 14 deletions
@@ -21,10 +21,13 @@
 from llm_stylometry.cli_utils import safe_print, format_header, is_windows


-def train_models(max_gpus=None, no_confirm=False):
-    """Train all models from scratch."""
+def train_models(max_gpus=None, no_confirm=False, resume=False):
+    """Train all models from scratch or resume from checkpoints."""
     safe_print("\n" + "=" * 60)
-    safe_print("Training Models from Scratch")
+    if resume:
+        safe_print("Resuming Model Training from Checkpoints")
+    else:
+        safe_print("Training Models from Scratch")
     safe_print("=" * 60)
     warning = "[WARNING]" if is_windows() else "⚠️"
     # Check device availability
@@ -51,19 +54,29 @@ def train_models(max_gpus=None, no_confirm=False):
         safe_print("\nSkipping confirmation (--no-confirm flag set)")
         safe_print("Starting training...")

-    # Remove existing models directory to train from scratch
+    # Handle models directory based on resume flag
     import shutil
     models_dir = Path('models')
-    if models_dir.exists():
-        safe_print("\nRemoving existing models directory...")
-        shutil.rmtree(models_dir)
-        safe_print("Existing models removed.")

-    # Also remove existing model results file
-    model_results_path = Path('data/model_results.pkl')
-    if model_results_path.exists():
-        safe_print("Removing existing model_results.pkl...")
-        model_results_path.unlink()
+    if not resume:
+        # Remove existing models directory to train from scratch
+        if models_dir.exists():
+            safe_print("\nRemoving existing models directory...")
+            shutil.rmtree(models_dir)
+            safe_print("Existing models removed.")
+
+        # Also remove existing model results file
+        model_results_path = Path('data/model_results.pkl')
+        if model_results_path.exists():
+            safe_print("Removing existing model_results.pkl...")
+            model_results_path.unlink()
+    else:
+        # When resuming, keep existing models and check their status
+        if models_dir.exists():
+            safe_print("\nResuming from existing models directory...")
+        else:
+            safe_print("\nNo existing models found. Starting fresh training...")
+            resume = False # Fall back to fresh training if no models exist

     # Prepare data if needed
     if not Path('data/cleaned').exists():
@@ -98,6 +111,9 @@ def train_models(max_gpus=None, no_confirm=False):
     if max_gpus:
         env['MAX_GPUS'] = str(max_gpus)
         safe_print(f"Limiting to {max_gpus} GPU(s)")
+    # Pass through resume flag if specified
+    if resume:
+        env['RESUME_TRAINING'] = '1'
     # Run without capturing output so we can see progress
     result = subprocess.run([sys.executable, 'code/main.py'], env=env, check=False)
     if result.returncode != 0:
@@ -227,6 +243,12 @@ def main():
         help='Skip confirmation prompts (useful for non-interactive mode)'
     )

+    parser.add_argument(
+        '--resume', '-r',
+        action='store_true',
+        help='Resume training from existing checkpoints (use with --train)'
+    )
+
     args = parser.parse_args()

     if args.list:
@@ -242,9 +264,14 @@ def main():

     safe_print(format_header("LLM Stylometry CLI", 60))

+    # Validate --resume flag usage
+    if args.resume and not args.train:
+        safe_print("\nWarning: --resume flag is ignored without --train flag")
+        args.resume = False
+
     # Train models if requested
     if args.train:
-        if not train_models(max_gpus=args.max_gpus, no_confirm=args.no_confirm):
+        if not train_models(max_gpus=args.max_gpus, no_confirm=args.no_confirm, resume=args.resume):
             return 1
         # Update data path to use newly generated results
         args.data = 'data/model_results.pkl'
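The hand-off between this CLI and the training driver is a single environment variable: `train_models()` sets `RESUME_TRAINING=1` before launching `code/main.py`, which reads it back as `resume_mode`. Below is a self-contained sketch of that contract, using a throwaway child process as a stand-in for the real training script:

```python
import os
import subprocess
import sys

# Parent side: mirror what train_models() does before launching code/main.py.
env = os.environ.copy()
env["RESUME_TRAINING"] = "1"

# Child side: the same check code/main.py performs, run here in a tiny stand-in process.
child_code = "import os; print(os.environ.get('RESUME_TRAINING', '0') == '1')"
result = subprocess.run(
    [sys.executable, "-c", child_code],
    env=env, capture_output=True, text=True, check=True,
)
print("resume_mode seen by child:", result.stdout.strip())  # -> True
```

Passing the flag through the environment rather than `sys.argv` leaves the child's argument handling untouched, and any process spawned with this environment inherits the setting.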

code/main.py

Lines changed: 93 additions & 0 deletions
@@ -36,6 +36,52 @@ def tqdm(iterable, *args, **kwargs):
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+def check_model_complete(model_name, stop_train_loss=3.0, min_epochs=0):
+    """
+    Check if a model has completed training based on loss logs and weights.
+
+    Returns:
+        tuple: (is_complete, has_weights, epochs_completed)
+            - is_complete: True if model has met stop criteria
+            - has_weights: True if model weights exist
+            - epochs_completed: Number of epochs completed (0 if no logs)
+    """
+    model_dir = MODELS_DIR / model_name
+
+    # Check if model weights exist
+    weights_file = model_dir / "model.safetensors"
+    config_file = model_dir / "config.json"
+    training_state_file = model_dir / "training_state.pt"
+    has_weights = weights_file.exists() and config_file.exists() and training_state_file.exists()
+
+    # Check loss logs
+    loss_log_file = model_dir / "loss_logs.csv"
+    if not loss_log_file.exists():
+        return False, has_weights, 0
+
+    # Read loss logs to check training status
+    import pandas as pd
+    try:
+        df = pd.read_csv(loss_log_file)
+        if df.empty:
+            return False, has_weights, 0
+
+        # Get the last training loss for this model
+        train_losses = df[df['loss_dataset'] == 'train'].sort_values('epochs_completed')
+        if train_losses.empty:
+            return False, has_weights, 0
+
+        last_epoch = train_losses['epochs_completed'].max()
+        last_train_loss = train_losses[train_losses['epochs_completed'] == last_epoch]['loss_value'].iloc[0]
+
+        # Check if model has met stop criteria
+        is_complete = (last_train_loss <= stop_train_loss and last_epoch >= min_epochs)
+
+        return is_complete, has_weights, int(last_epoch)
+    except Exception as e:
+        logger.warning(f"Error reading loss logs for {model_name}: {e}")
+        return False, has_weights, 0
+
 # Detect available devices
 def get_device_info():
     """Detect and return device configuration."""
@@ -51,6 +97,9 @@ def get_device_info():
 device_type, device_count = get_device_info()
 logger.info(f"Device type: {device_type}, Count: {device_count}")

+# Check if we're in resume mode
+resume_mode = os.environ.get('RESUME_TRAINING', '0') == '1'
+
 experiments = []
 for seed in range(10):
     for author in AUTHORS:
@@ -59,6 +108,7 @@ def get_device_info():
                 train_author=author,
                 seed=seed,
                 tokenizer_name="gpt2",
+                resume_training=resume_mode,
             )
         )

@@ -298,6 +348,49 @@ def run_experiment(exp: Experiment, device_queue, device_type="cuda"):
 # Check if we should run sequentially (for subprocess compatibility)
 USE_MULTIPROCESSING = os.environ.get('NO_MULTIPROCESSING', '0') != '1'

+# Filter experiments based on resume mode
+if resume_mode:
+    logger.info("Checking existing models for resume...")
+    experiments_to_run = []
+    import shutil
+
+    for exp in experiments:
+        is_complete, has_weights, epochs_done = check_model_complete(
+            exp.name,
+            exp.stop_criteria["train_loss"],
+            exp.stop_criteria["min_epochs"]
+        )
+
+        if is_complete:
+            # Model has completed training - skip it
+            logger.info(f"Skipping {exp.name} - already complete (epochs: {epochs_done})")
+        elif has_weights:
+            # Model has weights and can be resumed
+            logger.info(f"Resuming {exp.name} from epoch {epochs_done}")
+            experiments_to_run.append(exp)
+        elif epochs_done > 0:
+            # Loss logs exist but no weights (e.g., after cloning repo) - need to restart
+            logger.info(f"Starting {exp.name} from scratch - no weights available (removing existing logs)")
+            model_dir = MODELS_DIR / exp.name
+            if model_dir.exists():
+                # Remove only this specific model's directory to start fresh
+                shutil.rmtree(model_dir)
+            exp.resume_training = False # Force fresh start for this model
+            experiments_to_run.append(exp)
+        else:
+            # No logs or weights - start fresh for this model
+            logger.info(f"Starting fresh: {exp.name} (no existing logs or weights)")
+            exp.resume_training = False # No checkpoint to resume from
+            experiments_to_run.append(exp)
+
+    experiments = experiments_to_run
+    total_models = 80 # 8 authors × 10 seeds
+    logger.info(f"Models to train: {len(experiments)} out of {total_models} total")
+
+    if not experiments:
+        logger.info("All models are complete. Nothing to train.")
+        sys.exit(0)
+
 # Use already detected device configuration
 if device_type == "cuda":
     # Check for MAX_GPUS environment variable to optionally limit GPU usage
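To make the resume decision concrete: `check_model_complete` keys off the last `train` row in a model's `loss_logs.csv` plus the presence of weight files, and the filter above then skips, resumes, or restarts each experiment accordingly. Below is a small standalone illustration of the same stopping rule against an in-memory table; the column names come from this diff, while the example values and thresholds are made up for the demo:

```python
import pandas as pd

# Columns mirror the loss_logs.csv schema read by check_model_complete.
df = pd.DataFrame({
    "loss_dataset":     ["train", "val", "train", "val"],
    "epochs_completed": [1,        1,     2,       2],
    "loss_value":       [4.2,      4.5,   2.9,     3.3],
})

stop_train_loss, min_epochs = 3.0, 2  # example stop criteria

# Same logic as check_model_complete: take the latest training loss...
train_losses = df[df["loss_dataset"] == "train"].sort_values("epochs_completed")
last_epoch = train_losses["epochs_completed"].max()
last_train_loss = train_losses[
    train_losses["epochs_completed"] == last_epoch
]["loss_value"].iloc[0]

# ...and compare it against the stop criteria.
is_complete = last_train_loss <= stop_train_loss and last_epoch >= min_epochs
print(f"epoch {last_epoch}: train loss {last_train_loss} -> complete={is_complete}")
# epoch 2: train loss 2.9 -> complete=True
```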

code/model_utils.py

Lines changed: 30 additions & 1 deletion
@@ -3,6 +3,8 @@
 import logging
 from torch.optim import AdamW
 from constants import MODELS_DIR
+import random
+import numpy as np

 logger = logging.getLogger(__name__)

@@ -18,10 +20,19 @@ def save_checkpoint(

     model.save_pretrained(save_directory=checkpoint_dir)

+    # Save training state including random states for deterministic resume
     training_state = {
         "optimizer_state_dict": optimizer.state_dict(),
         "epochs_completed": epochs_completed,
+        "random_state": random.getstate(),
+        "np_random_state": np.random.get_state(),
+        "torch_random_state": torch.get_rng_state(),
     }
+
+    # Also save CUDA random state if available
+    if torch.cuda.is_available():
+        training_state["cuda_random_state"] = torch.cuda.get_rng_state_all()
+
     torch.save(obj=training_state, f=checkpoint_dir / "training_state.pt")
     logger.info(
         f"Checkpoint saved for {model_name} at epochs_completed={epochs_completed}"
@@ -42,11 +53,29 @@ def load_checkpoint(model_class, model_name, device):
     if not training_state_path.exists():
         raise FileNotFoundError(f"Training state file not found for {model_name}")

-    training_state = torch.load(f=training_state_path)
+    training_state = torch.load(f=training_state_path, map_location=device)

     optimizer = AdamW(params=model.parameters(), lr=0)
     optimizer.load_state_dict(state_dict=training_state["optimizer_state_dict"])
     epochs_completed = training_state["epochs_completed"]
+
+    # Restore random states for deterministic resume (if available)
+    if "random_state" in training_state:
+        random.setstate(training_state["random_state"])
+        logger.info("Restored Python random state")
+
+    if "np_random_state" in training_state:
+        np.random.set_state(training_state["np_random_state"])
+        logger.info("Restored NumPy random state")
+
+    if "torch_random_state" in training_state:
+        torch.set_rng_state(training_state["torch_random_state"])
+        logger.info("Restored PyTorch random state")
+
+    if "cuda_random_state" in training_state and torch.cuda.is_available():
+        torch.cuda.set_rng_state_all(training_state["cuda_random_state"])
+        logger.info("Restored CUDA random state")
+
     logger.info(
         f"Checkpoint loaded for {model_name} from epochs_completed={epochs_completed}"
     )
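The reason the checkpoint carries RNG states is that a resumed run should draw the same random numbers it would have drawn had training never stopped. Here is a self-contained round trip demonstrating that property for the three CPU-side generators saved above (CUDA state omitted; the temp-file path and the `weights_only=False` argument are specific to this demo, which must deserialize plain Python objects on newer PyTorch):

```python
import random
import tempfile
from pathlib import Path

import numpy as np
import torch

# Capture the generator states, in the same spirit as save_checkpoint.
state = {
    "random_state": random.getstate(),
    "np_random_state": np.random.get_state(),
    "torch_random_state": torch.get_rng_state(),
}
ckpt = Path(tempfile.mkdtemp()) / "training_state.pt"
torch.save(obj=state, f=ckpt)

# Draw once "before the interruption"...
expected = (random.random(), np.random.rand(), torch.rand(1).item())

# ...then restore the saved states and draw again: the values match exactly.
restored = torch.load(f=ckpt, weights_only=False)
random.setstate(restored["random_state"])
np.random.set_state(restored["np_random_state"])
torch.set_rng_state(restored["torch_random_state"])
resumed = (random.random(), np.random.rand(), torch.rand(1).item())

assert resumed == expected
print("RNG streams identical after restore:", resumed == expected)
```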
