File tree Expand file tree Collapse file tree 2 files changed +16
-7
lines changed
Expand file tree Collapse file tree 2 files changed +16
-7
lines changed Original file line number Diff line number Diff line change @@ -498,13 +498,19 @@ def train_step(self, batch: Dict) -> Dict:
498498 """
499499 Single training step on a batch.
500500
501- Args:
502- batch: Dictionary with:
503- - 'audio_segments': List of audio segment tensors per song
504- - 'teacher_embeddings': Teacher embeddings from database
505- - 'song_ids': Song IDs for logging
506-
507- Returns:
501+ if torch.cuda.is_available() and str(self.device) == 'cuda':
502+ self.model.to(self.device, dtype=torch.bfloat16)
503+ self._cast_batchnorm_to_dtype(torch.bfloat16)
504+ tensor_dtype = torch.bfloat16
505+ elif torch.backends.mps.is_available() and str(self.device) == 'mps':
506+ self.model.to(self.device, dtype=torch.bfloat16)
507+ self._cast_batchnorm_to_dtype(torch.bfloat16)
508+ tensor_dtype = torch.bfloat16
509+ else:
510+ self.model.to(self.device, dtype=torch.float32)
511+ self._cast_batchnorm_to_dtype(torch.float32)
512+ tensor_dtype = torch.float32
513+ self.model.train()
508514 step_metrics: Dictionary with loss and performance metrics
509515 """
510516
Original file line number Diff line number Diff line change @@ -282,12 +282,15 @@ def validate_real(trainer: StudentCLAPTrainer,
282282 # Set model to correct dtype for platform (match training)
283283 if torch .cuda .is_available () and str (trainer .device ) == 'cuda' :
284284 trainer .model .to (trainer .device , dtype = torch .bfloat16 )
285+ trainer ._cast_batchnorm_to_dtype (torch .bfloat16 )
285286 tensor_dtype = torch .bfloat16
286287 elif torch .backends .mps .is_available () and str (trainer .device ) == 'mps' :
287288 trainer .model .to (trainer .device , dtype = torch .bfloat16 )
289+ trainer ._cast_batchnorm_to_dtype (torch .bfloat16 )
288290 tensor_dtype = torch .bfloat16
289291 else :
290292 trainer .model .to (trainer .device , dtype = torch .float32 )
293+ trainer ._cast_batchnorm_to_dtype (torch .float32 )
291294 tensor_dtype = torch .float32
292295
293296 # Collect embeddings
You can’t perform that action at this time.
0 commit comments