
Commit 19da454

LOCAL STUDENT - passing to adamw
1 parent 93464d1 commit 19da454

File tree

3 files changed (+83, -28 lines)


student_clap/config.yaml

Lines changed: 9 additions & 7 deletions
```diff
@@ -13,11 +13,13 @@ audio:
 
 model:
   embedding_dim: 512
-  # PhiNet 3 configuration (tinyCLAP: alpha=3.0, beta=0.75, t0=4, N=7 → 6.2M params)
+  # PhiNet 3 configuration (tinyCLAP: alpha=3.0, beta=0.75, t0=4, N=7 → 4M+ params)
+  # PhiNet NEW1 configuration (tinyCLAP: alpha=3.0, beta=0.75, t0=6, N=8 → 8M+ params)
+
   phinet_alpha: 3.0
   phinet_beta: 0.75
-  phinet_t0: 6
-  phinet_N: 8
+  phinet_t0: 4
+  phinet_N: 7
   hidden_dim: 256
   dropout: 0.1
   use_gradient_checkpointing: true
@@ -32,13 +34,13 @@ model_text:
 training:
   batch_size: 1
   gradient_accumulation_steps: 8
-  learning_rate: 0.003
+  learning_rate: 0.0003
   epochs: 100
   stage2_epochs: 10
-  stage2_learning_rate: 0.001
+  stage2_learning_rate: 0.0001
   projection_only: false
-  optimizer: "adam"
-  weight_decay: 0.0
+  optimizer: "adamw"
+  weight_decay: 0.0001
   grad_clip: 5.0
   training_strategy: "both"
   save_every: 1
```
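As a sanity check on the new values, here is a minimal sketch of reading these training keys back out of the YAML and constructing the matching optimizer. The helper function and config path are illustrative; only the key names and values come from the diff above.

```python
# Illustrative only: reads the training keys changed above and builds the optimizer.
# The helper name and config path are hypothetical; key names/values match the diff.
import yaml
import torch


def build_optimizer(model: torch.nn.Module,
                    config_path: str = "student_clap/config.yaml") -> torch.optim.Optimizer:
    with open(config_path) as f:
        config = yaml.safe_load(f)

    train_cfg = config["training"]
    opt_name = train_cfg.get("optimizer", "adam").lower()  # "adamw" after this commit
    lr = train_cfg["learning_rate"]                        # 0.0003
    wd = train_cfg["weight_decay"]                         # 0.0001

    opt_cls = torch.optim.AdamW if opt_name == "adamw" else torch.optim.Adam
    return opt_cls(model.parameters(), lr=lr, weight_decay=wd)
```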

student_clap/models/student_onnx_model.py

Lines changed: 18 additions & 7 deletions
```diff
@@ -375,23 +375,34 @@ def __init__(self, config: Dict):
 
         self.model = StudentCLAPAudio(config).to(self.device).float()
 
-        self.optimizer = torch.optim.Adam(
-            self.model.parameters(),
-            lr=config['training']['learning_rate'],
-            weight_decay=config['training']['weight_decay']
-        )
+        # Support configurable optimizer: 'adam' (default) or 'adamw'
+        optimizer_type = config['training'].get('optimizer', 'adam').lower()
+        if optimizer_type == 'adamw':
+            self.optimizer = torch.optim.AdamW(
+                self.model.parameters(),
+                lr=config['training']['learning_rate'],
+                weight_decay=config['training']['weight_decay']
+            )
+            logger.info("🔧 Using AdamW optimizer")
+        else:
+            self.optimizer = torch.optim.Adam(
+                self.model.parameters(),
+                lr=config['training']['learning_rate'],
+                weight_decay=config['training']['weight_decay']
+            )
 
         self.gradient_accumulation_steps = config['training'].get('gradient_accumulation_steps', 1)
         self.accumulation_counter = 0
 
+        # Use validation-driven scheduler (mode='max' because we maximize cosine similarity)
         self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
             self.optimizer,
-            mode='min',
+            mode='max',
             factor=0.1,
             patience=10,
             min_lr=1e-6
         )
-        logger.info(f"📉 LR Scheduler: ReduceLROnPlateau (factor=0.1, patience=10)")
+        logger.info(f"📉 LR Scheduler: ReduceLROnPlateau (factor=0.1, patience=10, mode=max)")
 
         self.training_strategy = config['training'].get('training_strategy', 'averaged')
         self.segment_batch_size = config['model'].get('segment_batch_size', 10)
```
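For reference, a standalone sketch of what the new `__init__` logic amounts to, with a toy `torch.nn.Linear` standing in for `StudentCLAPAudio` and config values mirroring config.yaml. The practical difference is that Adam folds `weight_decay` into the gradient as L2 regularization, while AdamW applies the decay directly to the weights, decoupled from the adaptive step; `mode='max'` makes the plateau scheduler treat the monitored cosine similarity as higher-is-better.

```python
# Standalone sketch of the optimizer selection and mode='max' scheduler above.
# A toy Linear layer stands in for StudentCLAPAudio; values mirror config.yaml.
import torch

config = {"training": {"optimizer": "adamw", "learning_rate": 3e-4, "weight_decay": 1e-4}}
model = torch.nn.Linear(512, 512)

optimizer_type = config["training"].get("optimizer", "adam").lower()
opt_cls = torch.optim.AdamW if optimizer_type == "adamw" else torch.optim.Adam
optimizer = opt_cls(
    model.parameters(),
    lr=config["training"]["learning_rate"],
    weight_decay=config["training"]["weight_decay"],
)

# mode='max': LR is reduced when the monitored metric (cosine similarity) stops increasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.1, patience=10, min_lr=1e-6
)
```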

student_clap/train_real.py

Lines changed: 56 additions & 14 deletions
```diff
@@ -218,9 +218,8 @@ def train_epoch_real(trainer: StudentCLAPTrainer,
     avg_mse = total_mse / num_batches if num_batches > 0 else 0.0
     avg_cosine_sim = total_cosine_sim / num_batches if num_batches > 0 else 0.0
 
-    # Update learning rate scheduler with loss (ReduceLROnPlateau monitors performance)
-    # Pass NEGATIVE cosine similarity as loss (we want to maximize similarity = minimize negative)
-    trainer.scheduler.step(-avg_cosine_sim)  # Use negative because we maximize cosine sim
+    # Scheduler stepping is handled after validation (we want to monitor validation cosine for generalization).
+    # Do not step scheduler here on training metric to avoid reducing LR based on training improvements.
     current_lr = trainer.optimizer.param_groups[0]['lr']
 
     epoch_time = time.time() - epoch_start_time
@@ -465,26 +464,62 @@ def train(config_path: str, resume: str = None):
         logger.info(f"📂 Loading audio checkpoint: {audio_resume_path}")
         checkpoint = torch.load(audio_resume_path, map_location=trainer.device)
         trainer.model.load_state_dict(checkpoint['model_state_dict'])
-        trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        try:
-            trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
-            logger.info(f"✅ Scheduler state restored")
-        except Exception as e:
-            logger.warning(f"⚠️ Could not restore scheduler state (scheduler type changed): {e}")
-            logger.warning(f"   Creating new scheduler with patience=3, threshold=0.005, threshold_mode='rel'")
+
+        # Attempt to restore optimizer state; if missing or failing, keep fresh optimizer and apply config LR/WD
+        if 'optimizer_state_dict' in checkpoint:
+            try:
+                trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                # Ensure LR and weight_decay align with config (override to config values)
+                new_lr = config['training']['learning_rate']
+                new_wd = config['training']['weight_decay']
+                for pg in trainer.optimizer.param_groups:
+                    pg['lr'] = new_lr
+                    pg['weight_decay'] = new_wd
+                logger.info(f"✓ Optimizer restored from checkpoint and LR/WD overridden to config (lr={new_lr}, wd={new_wd})")
+            except Exception as e:
+                logger.warning(f"⚠️ Could not restore optimizer state cleanly: {e}; using fresh optimizer with config values")
+                for pg in trainer.optimizer.param_groups:
+                    pg['lr'] = config['training']['learning_rate']
+                    pg['weight_decay'] = config['training']['weight_decay']
+        else:
+            logger.info("No optimizer state in checkpoint — using fresh optimizer (config LR/WD applied)")
+            for pg in trainer.optimizer.param_groups:
+                pg['lr'] = config['training']['learning_rate']
+                pg['weight_decay'] = config['training']['weight_decay']
+
+        # Attempt to restore scheduler; if missing or failing, create a new one driven by validation (mode='max')
+        if 'scheduler_state_dict' in checkpoint:
+            try:
+                trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+                logger.info("✓ Scheduler restored from checkpoint")
+            except Exception as e:
+                logger.warning(f"⚠️ Could not restore scheduler state: {e}")
+                trainer.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                    trainer.optimizer,
+                    mode='max',
+                    factor=0.1,
+                    patience=3,
+                    threshold=0.005,
+                    threshold_mode='rel',
+                    min_lr=1e-6
+                )
+                logger.info("✓ Created new scheduler (mode=max) due to restore failure")
+        else:
             trainer.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                 trainer.optimizer,
-                mode='min',
+                mode='max',
                 factor=0.1,
                 patience=3,
                 threshold=0.005,
                 threshold_mode='rel',
                 min_lr=1e-6
             )
-        start_epoch = checkpoint['epoch'] + 1
+            logger.info("No scheduler state in checkpoint — created fresh scheduler (mode=max)")
+
+        start_epoch = checkpoint.get('epoch', 0) + 1
         best_val_cosine = checkpoint.get('best_val_cosine', 0.0)
         patience_counter = checkpoint.get('patience_counter', 0)
-        logger.info(f"✅ Successfully resumed audio from epoch {checkpoint['epoch']}")
+        logger.info(f"✅ Successfully resumed audio from epoch {checkpoint.get('epoch', 'N/A')}")
         logger.info(f"   📈 Best cosine similarity so far: {best_val_cosine:.4f}")
         logger.info(f"   ⏰ Patience counter: {patience_counter}/{config['training']['early_stopping_patience']}")
         logger.info(f"   🎯 Will continue from epoch {start_epoch}")
```
```diff
@@ -690,7 +725,7 @@ def train(config_path: str, resume: str = None):
 
         trainer.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
             trainer.optimizer,
-            mode='min',
+            mode='max',
             factor=0.1,
             patience=3,
             threshold=0.005,
@@ -804,6 +839,13 @@ def train(config_path: str, resume: str = None):
             except Exception as e:
                 logger.warning(f"⚠️ Failed to update epoch checkpoint with validation metrics: {e}")
 
+            # Step scheduler on validation metric (we monitor cosine similarity - higher is better)
+            try:
+                trainer.scheduler.step(val_cosine)
+                logger.info(f"Scheduler stepped using validation cosine: {val_cosine:.4f}")
+            except Exception as e:
+                logger.warning(f"Failed to step scheduler on validation metric: {e}")
+
             # Check for improvement (use cosine similarity as main metric)
             if val_cosine > best_val_cosine:
                 best_val_cosine = val_cosine
```
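Putting the scheduling pieces together, here is a minimal sketch of the validation-driven loop this commit moves toward: the scheduler is stepped once per epoch on validation cosine similarity, and the best value is tracked for checkpointing and early stopping. The fake validation values and the simplified loop are placeholders; the hyperparameters mirror the fallback scheduler above.

```python
# Minimal sketch of validation-driven LR scheduling and best-metric tracking.
# The validation values are fabricated for illustration only.
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.1, patience=3,
    threshold=0.005, threshold_mode="rel", min_lr=1e-6,
)

best_val_cosine = 0.0
for epoch in range(5):
    # ... training pass for one epoch would go here ...
    val_cosine = 0.5 + 0.01 * epoch           # fake validation cosine for illustration
    scheduler.step(val_cosine)                # higher is better under mode='max'
    if val_cosine > best_val_cosine:
        best_val_cosine = val_cosine          # save the best checkpoint here
```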
