Skip to content

Commit c6f2b85

Browse files
committed
LOCAL STUDENT - fix validation
1 parent 708ae40 commit c6f2b85

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

student_clap/models/student_onnx_model.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -498,13 +498,19 @@ def train_step(self, batch: Dict) -> Dict:
498498
"""
499499
Single training step on a batch.
500500
501-
Args:
502-
batch: Dictionary with:
503-
- 'audio_segments': List of audio segment tensors per song
504-
- 'teacher_embeddings': Teacher embeddings from database
505-
- 'song_ids': Song IDs for logging
506-
507-
Returns:
501+
if torch.cuda.is_available() and str(self.device) == 'cuda':
502+
self.model.to(self.device, dtype=torch.bfloat16)
503+
self._cast_batchnorm_to_dtype(torch.bfloat16)
504+
tensor_dtype = torch.bfloat16
505+
elif torch.backends.mps.is_available() and str(self.device) == 'mps':
506+
self.model.to(self.device, dtype=torch.bfloat16)
507+
self._cast_batchnorm_to_dtype(torch.bfloat16)
508+
tensor_dtype = torch.bfloat16
509+
else:
510+
self.model.to(self.device, dtype=torch.float32)
511+
self._cast_batchnorm_to_dtype(torch.float32)
512+
tensor_dtype = torch.float32
513+
self.model.train()
508514
step_metrics: Dictionary with loss and performance metrics
509515
"""
510516

student_clap/train_real.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,15 @@ def validate_real(trainer: StudentCLAPTrainer,
282282
# Set model to correct dtype for platform (match training)
283283
if torch.cuda.is_available() and str(trainer.device) == 'cuda':
284284
trainer.model.to(trainer.device, dtype=torch.bfloat16)
285+
trainer._cast_batchnorm_to_dtype(torch.bfloat16)
285286
tensor_dtype = torch.bfloat16
286287
elif torch.backends.mps.is_available() and str(trainer.device) == 'mps':
287288
trainer.model.to(trainer.device, dtype=torch.bfloat16)
289+
trainer._cast_batchnorm_to_dtype(torch.bfloat16)
288290
tensor_dtype = torch.bfloat16
289291
else:
290292
trainer.model.to(trainer.device, dtype=torch.float32)
293+
trainer._cast_batchnorm_to_dtype(torch.float32)
291294
tensor_dtype = torch.float32
292295

293296
# Collect embeddings

0 commit comments

Comments
 (0)