@@ -218,9 +218,8 @@ def train_epoch_real(trainer: StudentCLAPTrainer,
     avg_mse = total_mse / num_batches if num_batches > 0 else 0.0
     avg_cosine_sim = total_cosine_sim / num_batches if num_batches > 0 else 0.0
 
-    # Update learning rate scheduler with loss (ReduceLROnPlateau monitors performance)
-    # Pass NEGATIVE cosine similarity as loss (we want to maximize similarity = minimize negative)
-    trainer.scheduler.step(-avg_cosine_sim)  # Use negative because we maximize cosine sim
+    # Scheduler stepping is handled after validation (we want to monitor validation cosine for generalization).
+    # Do not step the scheduler here on a training metric, to avoid reducing LR based on training improvements.
     current_lr = trainer.optimizer.param_groups[0]['lr']
 
     epoch_time = time.time() - epoch_start_time
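For context, the loop this change assumes looks roughly like the sketch below: the training epoch no longer touches the scheduler, and the step happens once per epoch on the validation cosine. The names train_loader, val_loader, validate and num_epochs are illustrative placeholders, not code from this file.

# Minimal sketch of the intended per-epoch flow (illustrative names, not the actual loop):
for epoch in range(start_epoch, num_epochs):
    train_metrics = train_epoch_real(trainer, train_loader, epoch)  # no scheduler.step() inside anymore
    val_cosine = validate(trainer, val_loader)                      # mean validation cosine, higher is better
    trainer.scheduler.step(val_cosine)                              # ReduceLROnPlateau(mode='max') sees the validation metric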
@@ -465,26 +464,62 @@ def train(config_path: str, resume: str = None):
         logger.info(f"📂 Loading audio checkpoint: {audio_resume_path}")
         checkpoint = torch.load(audio_resume_path, map_location=trainer.device)
         trainer.model.load_state_dict(checkpoint['model_state_dict'])
-        trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        try:
-            trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
-            logger.info(f"✅ Scheduler state restored")
-        except Exception as e:
-            logger.warning(f"⚠️ Could not restore scheduler state (scheduler type changed): {e}")
-            logger.warning(f"   Creating new scheduler with patience=3, threshold=0.005, threshold_mode='rel'")
+
+        # Attempt to restore optimizer state; if missing or failing, keep the fresh optimizer and apply config LR/WD
+        if 'optimizer_state_dict' in checkpoint:
+            try:
+                trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                # Ensure LR and weight_decay align with config (override to config values)
+                new_lr = config['training']['learning_rate']
+                new_wd = config['training']['weight_decay']
+                for pg in trainer.optimizer.param_groups:
+                    pg['lr'] = new_lr
+                    pg['weight_decay'] = new_wd
+                logger.info(f"✓ Optimizer restored from checkpoint and LR/WD overridden to config (lr={new_lr}, wd={new_wd})")
+            except Exception as e:
+                logger.warning(f"⚠️ Could not restore optimizer state cleanly: {e}; using fresh optimizer with config values")
+                for pg in trainer.optimizer.param_groups:
+                    pg['lr'] = config['training']['learning_rate']
+                    pg['weight_decay'] = config['training']['weight_decay']
+        else:
+            logger.info("No optimizer state in checkpoint; using fresh optimizer (config LR/WD applied)")
+            for pg in trainer.optimizer.param_groups:
+                pg['lr'] = config['training']['learning_rate']
+                pg['weight_decay'] = config['training']['weight_decay']
+
+        # Attempt to restore the scheduler; if missing or failing, create a new one driven by validation (mode='max')
+        if 'scheduler_state_dict' in checkpoint:
+            try:
+                trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+                logger.info("✓ Scheduler restored from checkpoint")
+            except Exception as e:
+                logger.warning(f"⚠️ Could not restore scheduler state: {e}")
+                trainer.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                    trainer.optimizer,
+                    mode='max',
+                    factor=0.1,
+                    patience=3,
+                    threshold=0.005,
+                    threshold_mode='rel',
+                    min_lr=1e-6
+                )
+                logger.info("✓ Created new scheduler (mode=max) due to restore failure")
+        else:
             trainer.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                 trainer.optimizer,
-                mode='min',
+                mode='max',
                 factor=0.1,
                 patience=3,
                 threshold=0.005,
                 threshold_mode='rel',
                 min_lr=1e-6
             )
-        start_epoch = checkpoint['epoch'] + 1
+            logger.info("No scheduler state in checkpoint; created fresh scheduler (mode=max)")
+
+        start_epoch = checkpoint.get('epoch', 0) + 1
         best_val_cosine = checkpoint.get('best_val_cosine', 0.0)
         patience_counter = checkpoint.get('patience_counter', 0)
-        logger.info(f"✅ Successfully resumed audio from epoch {checkpoint['epoch']}")
+        logger.info(f"✅ Successfully resumed audio from epoch {checkpoint.get('epoch', 'N/A')}")
         logger.info(f"   📈 Best cosine similarity so far: {best_val_cosine:.4f}")
         logger.info(f"   ⏰ Patience counter: {patience_counter}/{config['training']['early_stopping_patience']}")
         logger.info(f"   🎯 Will continue from epoch {start_epoch}")
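The resume branch above only requires 'model_state_dict'; everything else is read defensively, either with a membership check or checkpoint.get() with a default. A checkpoint written roughly as follows would be fully compatible (a sketch only; the actual save code lives elsewhere in the script and is not part of this diff).

# Sketch of a checkpoint layout the resume logic can consume (hypothetical save site):
torch.save({
    'epoch': epoch,                                           # optional, resume falls back to 0
    'model_state_dict': trainer.model.state_dict(),           # required
    'optimizer_state_dict': trainer.optimizer.state_dict(),   # optional, fresh optimizer used if absent
    'scheduler_state_dict': trainer.scheduler.state_dict(),   # optional, fresh mode='max' scheduler if absent
    'best_val_cosine': best_val_cosine,                       # optional, defaults to 0.0
    'patience_counter': patience_counter,                     # optional, defaults to 0
}, audio_resume_path)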
@@ -690,7 +725,7 @@ def train(config_path: str, resume: str = None):
 
     trainer.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
         trainer.optimizer,
-        mode='min',
+        mode='max',
         factor=0.1,
         patience=3,
         threshold=0.005,
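With mode='max', threshold=0.005 and threshold_mode='rel', ReduceLROnPlateau counts an epoch as an improvement only when the monitored value beats the best seen so far by more than 0.5% relative (for a positive metric, new > best * 1.005); after patience=3 consecutive non-improving epochs the learning rate is multiplied by factor=0.1, never dropping below min_lr=1e-6. A quick worked check of the relative threshold:

# Relative-threshold check for mode='max' (PyTorch compares new > best * (1 + threshold) for a positive metric):
best, threshold = 0.80, 0.005
bar = best * (1 + threshold)   # 0.804
print(0.803 > bar)             # False -> counts toward patience
print(0.805 > bar)             # True  -> improvement, patience resets and best is updated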
@@ -804,6 +839,13 @@ def train(config_path: str, resume: str = None):
         except Exception as e:
             logger.warning(f"⚠️ Failed to update epoch checkpoint with validation metrics: {e}")
 
+        # Step scheduler on validation metric (we monitor cosine similarity - higher is better)
+        try:
+            trainer.scheduler.step(val_cosine)
+            logger.info(f"Scheduler stepped using validation cosine: {val_cosine:.4f}")
+        except Exception as e:
+            logger.warning(f"Failed to step scheduler on validation metric: {e}")
+
         # Check for improvement (use cosine similarity as main metric)
         if val_cosine > best_val_cosine:
             best_val_cosine = val_cosine