
Commit a6e64bb

LOCAL STUDENT - CAST REMOVED
1 parent 6d744e7 commit a6e64bb

2 files changed: +27 -91 lines changed


student_clap/models/student_onnx_model.py

Lines changed: 22 additions & 70 deletions
@@ -330,7 +330,7 @@ def process_audio_segments(self, audio_segments: torch.Tensor) -> torch.Tensor:
         """
 
         model_device = next(self.parameters()).device
-        audio_segments = audio_segments.to(model_device, dtype=torch.float32)
+        audio_segments = audio_segments.to(model_device)
 
         mel_specs = self.compute_mel_spectrogram(audio_segments)
 
@@ -356,21 +356,6 @@ def count_parameters(self) -> Dict[str, int]:
         }
 
 class StudentCLAPTrainer:
-    def _cast_batchnorm_to_float32(self):
-        """Cast all BatchNorm layers in the model to float32 for all platforms (CUDA, Mac MPS, CPU)."""
-        for module in self.model.modules():
-            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-                module.to(dtype=torch.float32)
-
-    def _cast_nonbatchnorm_to_dtype(self, dtype):
-        """Cast all non-BatchNorm layers in the model to the given dtype."""
-        for module in self.model.modules():
-            if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-                if hasattr(module, 'to'):
-                    try:
-                        module.to(dtype=dtype)
-                    except Exception:
-                        pass
     """
     ONNX-compatible trainer for Student CLAP using PyTorch.
 
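Side note (not part of this commit): the deleted helpers existed to keep BatchNorm layers in float32 while the rest of the model ran in bfloat16. If mixed precision is ever reintroduced, the usual PyTorch pattern is to leave the parameters in float32 and wrap only the forward/loss in torch.autocast, which picks per-op precision and avoids manual per-module casting. A minimal sketch under that assumption (function and variable names here are illustrative, not from this repo; autocast on CPU/MPS depends on the PyTorch version):

    import torch

    def train_step_autocast(model, optimizer, mel_segments, device):
        """Hypothetical mixed-precision step: params stay float32, ops autocast to bfloat16."""
        model.train()
        optimizer.zero_grad()
        mel_segments = mel_segments.to(device)
        # Autocast only the forward pass; per-op precision is chosen automatically,
        # so no manual BatchNorm-vs-rest casting is needed.
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            embeddings = model(mel_segments)
            loss = embeddings.pow(2).mean()  # placeholder loss for the sketch
        loss.backward()   # gradients land in float32 because parameters are float32
        optimizer.step()
        return loss.item()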
@@ -382,18 +367,15 @@ def _cast_nonbatchnorm_to_dtype(self, dtype):
     def __init__(self, config: Dict):
         self.config = config
 
-        # --- Device and precision autodetection ---
+        # --- Device autodetection, always use float32 ---
         if torch.cuda.is_available():
             self.device = torch.device('cuda')
-            self.dtype = torch.bfloat16
         elif torch.backends.mps.is_available():
             self.device = torch.device('mps')
-            self.dtype = torch.bfloat16  # Use bfloat16 for Mac (MPS)
         else:
             self.device = torch.device('cpu')
-            self.dtype = torch.float32
 
-        self.model = StudentCLAPAudio(config).to(self.device, dtype=self.dtype)
+        self.model = StudentCLAPAudio(config).to(self.device)
 
         # Support configurable optimizer: 'adam' (default) or 'adamw'
         optimizer_type = config['training'].get('optimizer', 'adam').lower()
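Only the optimizer selector line is visible in this hunk. For orientation, a hedged sketch of what such a config-driven 'adam' / 'adamw' switch typically expands to; the learning_rate and weight_decay keys are assumptions, not confirmed by the diff:

    import torch

    def build_optimizer(model: torch.nn.Module, training_cfg: dict) -> torch.optim.Optimizer:
        """Hypothetical expansion of the 'adam' / 'adamw' selector shown above."""
        optimizer_type = training_cfg.get('optimizer', 'adam').lower()
        lr = training_cfg.get('learning_rate', 1e-4)          # assumed config key
        weight_decay = training_cfg.get('weight_decay', 0.0)  # assumed config key
        cls = torch.optim.AdamW if optimizer_type == 'adamw' else torch.optim.Adam
        return cls(model.parameters(), lr=lr, weight_decay=weight_decay)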
@@ -432,13 +414,11 @@ def __init__(self, config: Dict):
             logger.info("🔒 STAGE 2: Freezing encoder, training projection head only")
             self._freeze_encoder()
 
-        logger.info(f"Initialized Student CLAP trainer on {self.device} (precision: {self.dtype})")
+        logger.info(f"Initialized Student CLAP trainer on {self.device}")
         logger.info(f"Model parameters: {self.model.count_parameters()}")
         logger.info(f"Training strategy: {self.training_strategy}")
 
-    @property
-    def device_dtype(self):
-        return self.device, self.dtype
+    #
 
     def _freeze_encoder(self):
         """Freeze encoder layers, keep only projection head trainable (Stage 2)."""
@@ -474,20 +454,12 @@ def compute_loss(self,
             loss_dict: Individual loss components for logging
         """
 
-        # Always cast both tensors to the same dtype as the model/device
-        target_dtype = self.dtype
-        if torch.cuda.is_available() and str(self.device) == 'cuda':
-            target_dtype = torch.bfloat16
-        elif torch.backends.mps.is_available() and str(self.device) == 'mps':
-            target_dtype = torch.float16
-        else:
-            target_dtype = torch.float32
-
+        # Always use default float32 for all tensors
         if not isinstance(teacher_embeddings, torch.Tensor):
-            teacher_embeddings = torch.from_numpy(teacher_embeddings).to(dtype=target_dtype, device=self.device)
+            teacher_embeddings = torch.from_numpy(teacher_embeddings).to(self.device)
         else:
-            teacher_embeddings = teacher_embeddings.to(dtype=target_dtype, device=self.device)
-        student_embeddings = student_embeddings.to(dtype=target_dtype, device=self.device)
+            teacher_embeddings = teacher_embeddings.to(self.device)
+        student_embeddings = student_embeddings.to(self.device)
 
         teacher_embeddings = F.normalize(teacher_embeddings, p=2, dim=1)
         student_embeddings = F.normalize(student_embeddings, p=2, dim=1)
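The rest of compute_loss is not shown in this hunk. For orientation, a typical distillation loss on L2-normalized embeddings looks like the hedged sketch below; the exact loss terms used in this repo are not confirmed by the diff:

    import torch
    import torch.nn.functional as F

    def cosine_distillation_loss(student: torch.Tensor, teacher: torch.Tensor):
        """Hypothetical loss: 1 - cosine similarity between normalized embeddings."""
        student = F.normalize(student, p=2, dim=1)
        teacher = F.normalize(teacher, p=2, dim=1)
        cos_sim = (student * teacher).sum(dim=1)   # per-sample cosine similarity
        loss = (1.0 - cos_sim).mean()              # minimize distance to the teacher
        return loss, {'cosine_loss': loss.item(), 'mean_cos_sim': cos_sim.mean().item()}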
@@ -513,32 +485,13 @@ def train_step(self, batch: Dict) -> Dict:
         """
         Single training step on a batch.
 
-        if torch.cuda.is_available() and str(self.device) == 'cuda':
-            self.model.to(self.device)
-            self._cast_nonbatchnorm_to_dtype(torch.bfloat16)
-            self._cast_batchnorm_to_float32()
-            tensor_dtype = torch.bfloat16
-        elif torch.backends.mps.is_available() and str(self.device) == 'mps':
-            self.model.to(self.device)
-            self._cast_nonbatchnorm_to_dtype(torch.bfloat16)
-            self._cast_batchnorm_to_float32()
-            tensor_dtype = torch.bfloat16
-        else:
-            self.model.to(self.device)
-            self._cast_nonbatchnorm_to_dtype(torch.float32)
-            self._cast_batchnorm_to_float32()
-            tensor_dtype = torch.float32
+        # Always use default float32 for training
+        self.model.to(self.device)
         self.model.train()
            step_metrics: Dictionary with loss and performance metrics
         """
 
-        # Only patch CUDA: force bfloat16, otherwise use autodetected self.dtype (float16 for MPS, float32 for CPU)
-        if torch.cuda.is_available() and str(self.device) == 'cuda':
-            self.model.to(self.device, dtype=torch.bfloat16)
-            tensor_dtype = torch.bfloat16
-        else:
-            self.model.to(self.device, dtype=self.dtype)
-            tensor_dtype = self.dtype
+        self.model.to(self.device)
         self.model.train()
 
         if self.accumulation_counter == 0:
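train_step only steps the optimizer on accumulation boundaries (the accumulation_counter check above). The general gradient-accumulation pattern, as a hedged sketch with illustrative names; accumulation_steps and the zero_grad/step placement are assumptions about this repo:

    import torch

    def accumulation_step(model, optimizer, loss: torch.Tensor,
                          counter: int, accumulation_steps: int) -> int:
        """Hypothetical gradient-accumulation step: scale loss, step every N micro-batches."""
        if counter == 0:
            optimizer.zero_grad()                  # start of a new accumulation window
        (loss / accumulation_steps).backward()     # average gradients across micro-batches
        counter += 1
        if counter >= accumulation_steps:
            optimizer.step()                       # apply the accumulated gradients
            counter = 0
        return counter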
@@ -553,20 +506,19 @@ def train_step(self, batch: Dict) -> Dict:
                 batch.get('teacher_segment_embeddings', [None] * len(batch['audio_segments']))
         )):
 
-            # Only use bfloat16 for CUDA, float32 everywhere else
-            # Move mel_segments to correct device/dtype
+            # Always use default float32 for all input tensors
             if not isinstance(mel_segments, torch.Tensor):
                 mel_segments = torch.from_numpy(mel_segments)
-            mel_segments = mel_segments.to(device=self.device, dtype=tensor_dtype)
+            mel_segments = mel_segments.to(self.device)
 
             # Move teacher_emb and teacher_segment_embs to correct device/dtype if tensor
             if isinstance(teacher_emb, np.ndarray):
                 teacher_emb = torch.from_numpy(teacher_emb)
             if isinstance(teacher_emb, torch.Tensor):
-                teacher_emb = teacher_emb.to(device=self.device, dtype=tensor_dtype)
+                teacher_emb = teacher_emb.to(self.device)
             if teacher_segment_embs is not None:
                 teacher_segment_embs = [torch.from_numpy(e) if isinstance(e, np.ndarray) else e for e in teacher_segment_embs]
-                teacher_segment_embs = [e.to(device=self.device, dtype=tensor_dtype) if isinstance(e, torch.Tensor) else e for e in teacher_segment_embs]
+                teacher_segment_embs = [e.to(self.device) if isinstance(e, torch.Tensor) else e for e in teacher_segment_embs]
 
             if mel_segments.shape[0] < 2:
                 logger.warning(f"⚠️ Skipping song {batch['song_ids'][i]} - only {mel_segments.shape[0]} segment (BatchNorm needs ≥2)")
@@ -580,7 +532,7 @@ def train_step(self, batch: Dict) -> Dict:
                     chunk_end = min(chunk_start + chunk_size, mel_segments.shape[0])
                     chunk = mel_segments[chunk_start:chunk_end]
                     # Ensure chunk is on correct device/dtype
-                    chunk = chunk.to(device=self.device, dtype=tensor_dtype)
+                    chunk = chunk.to(self.device)
                     chunk_embeddings = self.model.forward(chunk)
                     segment_embeddings_list.append(chunk_embeddings)
 
@@ -600,7 +552,7 @@ def train_step(self, batch: Dict) -> Dict:
                 for chunk_start in range(0, mel_segments.shape[0], chunk_size):
                     chunk_end = min(chunk_start + chunk_size, mel_segments.shape[0])
                     chunk = mel_segments[chunk_start:chunk_end]
-                    chunk = chunk.to(device=self.device, dtype=tensor_dtype)
+                    chunk = chunk.to(self.device)
                     chunk_embeddings = self.model.forward(chunk)
                     segment_embeddings_list.append(chunk_embeddings)
 
@@ -618,7 +570,7 @@ def train_step(self, batch: Dict) -> Dict:
                 for chunk_start in range(0, mel_segments.shape[0], chunk_size):
                     chunk_end = min(chunk_start + chunk_size, mel_segments.shape[0])
                     chunk = mel_segments[chunk_start:chunk_end]
-                    chunk = chunk.to(device=self.device, dtype=tensor_dtype)
+                    chunk = chunk.to(self.device)
                     chunk_embeddings = self.model.forward(chunk)
                     segment_embeddings_list.append(chunk_embeddings)
 
@@ -660,10 +612,10 @@ def train_step(self, batch: Dict) -> Dict:
             }
 
         # Concatenate and ensure all embeddings are on correct device/dtype
-        student_embeddings = torch.cat(student_embeddings, dim=0).to(device=self.device, dtype=tensor_dtype)
+        student_embeddings = torch.cat(student_embeddings, dim=0).to(self.device)
         teacher_embeddings = [torch.from_numpy(e) if isinstance(e, np.ndarray) else e for e in teacher_embeddings]
-        teacher_embeddings = [e.to(device=self.device, dtype=tensor_dtype) if isinstance(e, torch.Tensor) else e for e in teacher_embeddings]
-        teacher_embeddings = torch.cat([e.unsqueeze(0) if e.dim() == 1 else e for e in teacher_embeddings], dim=0).to(device=self.device, dtype=tensor_dtype)
+        teacher_embeddings = [e.to(self.device) if isinstance(e, torch.Tensor) else e for e in teacher_embeddings]
+        teacher_embeddings = torch.cat([e.unsqueeze(0) if e.dim() == 1 else e for e in teacher_embeddings], dim=0).to(self.device)
 
         loss, loss_dict = self.compute_loss(student_embeddings, teacher_embeddings)
 
student_clap/train_real.py

Lines changed: 5 additions & 21 deletions
@@ -90,10 +90,10 @@ def train_epoch_real(trainer: StudentCLAPTrainer,
        Dict with epoch metrics
    """
    # Print device, precision, LR, WD at epoch start
-   device, dtype = trainer.device_dtype if hasattr(trainer, 'device_dtype') else (trainer.device, getattr(trainer, 'dtype', 'float32'))
+   device = trainer.device
    lr = trainer.optimizer.param_groups[0]['lr']
    wd = trainer.optimizer.param_groups[0].get('weight_decay', None)
-   logger.info(f"🚀 REAL ONNX TRAINING - Epoch {epoch}/{config['training']['epochs']} | Device: {device} | Precision: {dtype} | LR: {lr} | WD: {wd}")
+   logger.info(f"🚀 REAL ONNX TRAINING - Epoch {epoch}/{config['training']['epochs']} | Device: {device} | LR: {lr} | WD: {wd}")
 
    batch_size = config['training']['batch_size']
 
@@ -279,22 +279,7 @@ def validate_real(trainer: StudentCLAPTrainer,
    logger.info(f"🔍 Running REAL validation (Epoch {epoch})...")
 
    trainer.model.eval()
-   # Set model to correct dtype for platform (match training)
-   if torch.cuda.is_available() and str(trainer.device) == 'cuda':
-       trainer.model.to(trainer.device)
-       trainer._cast_nonbatchnorm_to_dtype(torch.bfloat16)
-       trainer._cast_batchnorm_to_float32()
-       tensor_dtype = torch.bfloat16
-   elif torch.backends.mps.is_available() and str(trainer.device) == 'mps':
-       trainer.model.to(trainer.device)
-       trainer._cast_nonbatchnorm_to_dtype(torch.bfloat16)
-       trainer._cast_batchnorm_to_float32()
-       tensor_dtype = torch.bfloat16
-   else:
-       trainer.model.to(trainer.device)
-       trainer._cast_nonbatchnorm_to_dtype(torch.float32)
-       trainer._cast_batchnorm_to_float32()
-       tensor_dtype = torch.float32
+   trainer.model.to(trainer.device)
 
    # Collect embeddings
    student_embeddings_list = []
@@ -314,10 +299,9 @@ def validate_real(trainer: StudentCLAPTrainer,
 
        for item in batch_data:
            audio_segments = item['audio_segments']
-           # Move to correct device/dtype
            if not isinstance(audio_segments, torch.Tensor):
                audio_segments = torch.from_numpy(audio_segments)
-           audio_segments = audio_segments.to(device=trainer.device, dtype=tensor_dtype)
+           audio_segments = audio_segments.to(device=trainer.device)
            batch['audio_segments'].append(audio_segments)
            batch['teacher_embeddings'].append(item['teacher_embedding'])
            batch['song_ids'].append(item['item_id'])
@@ -330,7 +314,7 @@ def validate_real(trainer: StudentCLAPTrainer,
        # audio_segments are PRE-COMPUTED mel spectrograms! (num_segments, 1, 128, time)
        if not isinstance(audio_segments, torch.Tensor):
            audio_segments = torch.from_numpy(audio_segments)
-       audio_segments = audio_segments.to(dtype=tensor_dtype, device=trainer.device)
+       audio_segments = audio_segments.to(device=trainer.device)
        # ⚠️ SKIP SONGS WITH ONLY 1 SEGMENT (BatchNorm requires at least 2 samples)
        if audio_segments.shape[0] < 2:
            logger.warning(f"⚠️ Skipping song {batch['song_ids'][i]} in validation - only {audio_segments.shape[0]} segment")
