2525
2626from __future__ import annotations
2727
28+ import contextlib
2829import json
2930import os
3031import shutil
3536from dataclasses import dataclass
3637from pathlib import Path
3738
39+ import wandb
40+
3841# Add src and scripts to path for development
3942sys .path .insert (0 , str (Path (__file__ ).parent .parent / "src" ))
4043sys .path .insert (0 , str (Path (__file__ ).parent ))
@@ -318,8 +321,6 @@ def train_with_retry(config: TrainingConfig) -> int:
318321 """Run training with retry logic."""
319322 from verifiers .rl .trainer import RLConfig , RLTrainer
320323
321- os .environ ["WANDB_PROJECT" ] = config .wandb_project
322-
323324 print ("=" * 60 )
324325 print ("Abide GRPO Training" )
325326 print ("=" * 60 )
@@ -331,6 +332,29 @@ def train_with_retry(config: TrainingConfig) -> int:
331332 print ("=" * 60 )
332333 print ()
333334
335+ # Initialize wandb (wrapped in try-except to avoid crashing on sync issues)
336+ wandb_enabled = False
337+ if config .use_wandb :
338+ try :
339+ wandb .init (
340+ project = config .wandb_project ,
341+ name = f"grpo-{ config .model_name .split ('/' )[- 1 ]} " ,
342+ config = {
343+ "model" : config .model_name ,
344+ "num_prompts" : config .num_prompts ,
345+ "rollouts_per_example" : config .rollouts_per_example ,
346+ "batch_size" : config .batch_size ,
347+ "micro_batch_size" : config .micro_batch_size ,
348+ "learning_rate" : config .learning_rate ,
349+ "max_seq_len" : config .max_seq_len ,
350+ },
351+ )
352+ wandb_enabled = True
353+ print ("Wandb initialized successfully" )
354+ except Exception as e :
355+ print (f"Warning: Failed to initialize wandb: { e } " )
356+ print ("Continuing without wandb logging..." )
357+
334358 # Load forms
335359 forms = get_forms ()
336360 print (f"Forms: { len (forms )} ({ ', ' .join (forms .keys ())} )" )
@@ -414,10 +438,17 @@ def train_with_retry(config: TrainingConfig) -> int:
414438 if best_path :
415439 print (f"Best model: { best_path } " )
416440
441+ if wandb_enabled :
442+ with contextlib .suppress (Exception ):
443+ wandb .finish ()
444+
417445 return 0
418446
419447 except KeyboardInterrupt :
420448 print ("\n Training interrupted by user." )
449+ if wandb_enabled :
450+ with contextlib .suppress (Exception ):
451+ wandb .finish ()
421452 return 1
422453
423454 except Exception as e :
@@ -436,8 +467,14 @@ def train_with_retry(config: TrainingConfig) -> int:
436467 print (f"Will resume from { config .resume_from } " )
437468 else :
438469 print ("Max retries exceeded. Training failed." )
470+ if wandb_enabled :
471+ with contextlib .suppress (Exception ):
472+ wandb .finish ()
439473 return 1
440474
475+ if wandb_enabled :
476+ with contextlib .suppress (Exception ):
477+ wandb .finish ()
441478 return 1
442479
443480
0 commit comments