Skip to content

Commit 660ffd8

Browse files
committed
Add proper wandb online logging with safe error handling
- Replace WANDB_PROJECT env var with explicit wandb.init()
- Add wandb.finish() at all exit points using contextlib.suppress
- Wrap wandb.init() in try-except to prevent crashes on sync issues
- Add models/ and wandb/ to .gitignore
1 parent 318068e commit 660ffd8

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,7 @@ Thumbs.db
4848
.uv/
4949
logs/
5050
results/
51+
52+
# Training artifacts
53+
models/
54+
wandb/

scripts/train_grpo.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from __future__ import annotations
2727

28+
import contextlib
2829
import json
2930
import os
3031
import shutil
@@ -35,6 +36,8 @@
3536
from dataclasses import dataclass
3637
from pathlib import Path
3738

39+
import wandb
40+
3841
# Add src and scripts to path for development
3942
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
4043
sys.path.insert(0, str(Path(__file__).parent))
@@ -318,8 +321,6 @@ def train_with_retry(config: TrainingConfig) -> int:
318321
"""Run training with retry logic."""
319322
from verifiers.rl.trainer import RLConfig, RLTrainer
320323

321-
os.environ["WANDB_PROJECT"] = config.wandb_project
322-
323324
print("=" * 60)
324325
print("Abide GRPO Training")
325326
print("=" * 60)
@@ -331,6 +332,29 @@ def train_with_retry(config: TrainingConfig) -> int:
331332
print("=" * 60)
332333
print()
333334

335+
# Initialize wandb (wrapped in try-except to avoid crashing on sync issues)
336+
wandb_enabled = False
337+
if config.use_wandb:
338+
try:
339+
wandb.init(
340+
project=config.wandb_project,
341+
name=f"grpo-{config.model_name.split('/')[-1]}",
342+
config={
343+
"model": config.model_name,
344+
"num_prompts": config.num_prompts,
345+
"rollouts_per_example": config.rollouts_per_example,
346+
"batch_size": config.batch_size,
347+
"micro_batch_size": config.micro_batch_size,
348+
"learning_rate": config.learning_rate,
349+
"max_seq_len": config.max_seq_len,
350+
},
351+
)
352+
wandb_enabled = True
353+
print("Wandb initialized successfully")
354+
except Exception as e:
355+
print(f"Warning: Failed to initialize wandb: {e}")
356+
print("Continuing without wandb logging...")
357+
334358
# Load forms
335359
forms = get_forms()
336360
print(f"Forms: {len(forms)} ({', '.join(forms.keys())})")
@@ -414,10 +438,17 @@ def train_with_retry(config: TrainingConfig) -> int:
414438
if best_path:
415439
print(f"Best model: {best_path}")
416440

441+
if wandb_enabled:
442+
with contextlib.suppress(Exception):
443+
wandb.finish()
444+
417445
return 0
418446

419447
except KeyboardInterrupt:
420448
print("\nTraining interrupted by user.")
449+
if wandb_enabled:
450+
with contextlib.suppress(Exception):
451+
wandb.finish()
421452
return 1
422453

423454
except Exception as e:
@@ -436,8 +467,14 @@ def train_with_retry(config: TrainingConfig) -> int:
436467
print(f"Will resume from {config.resume_from}")
437468
else:
438469
print("Max retries exceeded. Training failed.")
470+
if wandb_enabled:
471+
with contextlib.suppress(Exception):
472+
wandb.finish()
439473
return 1
440474

475+
if wandb_enabled:
476+
with contextlib.suppress(Exception):
477+
wandb.finish()
441478
return 1
442479

443480

0 commit comments

Comments
 (0)