@@ -330,7 +330,7 @@ def process_audio_segments(self, audio_segments: torch.Tensor) -> torch.Tensor:
330330 """
331331
332332 model_device = next (self .parameters ()).device
333- audio_segments = audio_segments .to (model_device , dtype = torch . float32 )
333+ audio_segments = audio_segments .to (model_device )
334334
335335 mel_specs = self .compute_mel_spectrogram (audio_segments )
336336
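One behavioural note on this hunk: `.to(device)` without an explicit `dtype` keeps the tensor's existing dtype, so the caller is now responsible for handing in float32 audio. A minimal standalone illustration (not repository code):

```python
import torch

x = torch.randn(4, 16000, dtype=torch.float16)
y = x.to(torch.device("cpu"))                       # dtype preserved: float16
z = x.to(torch.device("cpu"), dtype=torch.float32)  # explicit upcast
print(y.dtype, z.dtype)  # torch.float16 torch.float32
```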
@@ -356,21 +356,6 @@ def count_parameters(self) -> Dict[str, int]:
         }

 class StudentCLAPTrainer:
-    def _cast_batchnorm_to_float32(self):
-        """Cast all BatchNorm layers in the model to float32 for all platforms (CUDA, Mac MPS, CPU)."""
-        for module in self.model.modules():
-            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-                module.to(dtype=torch.float32)
-
-    def _cast_nonbatchnorm_to_dtype(self, dtype):
-        """Cast all non-BatchNorm layers in the model to the given dtype."""
-        for module in self.model.modules():
-            if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-                if hasattr(module, 'to'):
-                    try:
-                        module.to(dtype=dtype)
-                    except Exception:
-                        pass
     """
     ONNX-compatible trainer for Student CLAP using PyTorch.

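The two deleted helpers implemented mixed precision by hand: BatchNorm modules pinned to float32, everything else cast down. If mixed precision is ever reintroduced, `torch.autocast` is the usual replacement: it runs precision-sensitive ops such as batch normalization in float32 without mutating module dtypes. A sketch under that assumption (illustrative, requires a CUDA device, not part of this commit):

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Conv1d(1, 8, kernel_size=3),
    torch.nn.BatchNorm1d(8),
).cuda()

x = torch.randn(4, 1, 64, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)  # conv runs in bfloat16; batch norm stays in float32
```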
@@ -382,18 +367,15 @@ def _cast_nonbatchnorm_to_dtype(self, dtype):
     def __init__(self, config: Dict):
         self.config = config

-        # --- Device and precision autodetection ---
+        # --- Device autodetection, always use float32 ---
         if torch.cuda.is_available():
             self.device = torch.device('cuda')
-            self.dtype = torch.bfloat16
         elif torch.backends.mps.is_available():
             self.device = torch.device('mps')
-            self.dtype = torch.bfloat16  # Use bfloat16 for Mac (MPS)
         else:
             self.device = torch.device('cpu')
-            self.dtype = torch.float32

-        self.model = StudentCLAPAudio(config).to(self.device, dtype=self.dtype)
+        self.model = StudentCLAPAudio(config).to(self.device)

         # Support configurable optimizer: 'adam' (default) or 'adamw'
         optimizer_type = config['training'].get('optimizer', 'adam').lower()
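Since float32 is PyTorch's default parameter dtype, dropping the `dtype=` argument is sufficient here. A standalone sketch of the same autodetection (the `pick_device` helper is illustrative, not in the repo):

```python
import torch

def pick_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

model = torch.nn.Linear(8, 4).to(pick_device())
assert next(model.parameters()).dtype == torch.float32  # torch default
```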
@@ -432,13 +414,10 @@ def __init__(self, config: Dict):
             logger.info("🔒 STAGE 2: Freezing encoder, training projection head only")
             self._freeze_encoder()

-        logger.info(f"Initialized Student CLAP trainer on {self.device} (precision: {self.dtype})")
+        logger.info(f"Initialized Student CLAP trainer on {self.device}")
         logger.info(f"Model parameters: {self.model.count_parameters()}")
         logger.info(f"Training strategy: {self.training_strategy}")

-    @property
-    def device_dtype(self):
-        return self.device, self.dtype

     def _freeze_encoder(self):
         """Freeze encoder layers, keep only projection head trainable (Stage 2)."""
@@ -474,20 +454,12 @@ def compute_loss(self,
             loss_dict: Individual loss components for logging
         """

-        # Always cast both tensors to the same dtype as the model/device
-        target_dtype = self.dtype
-        if torch.cuda.is_available() and str(self.device) == 'cuda':
-            target_dtype = torch.bfloat16
-        elif torch.backends.mps.is_available() and str(self.device) == 'mps':
-            target_dtype = torch.float16
-        else:
-            target_dtype = torch.float32
-
+        # Always use default float32 for all tensors
         if not isinstance(teacher_embeddings, torch.Tensor):
-            teacher_embeddings = torch.from_numpy(teacher_embeddings).to(dtype=target_dtype, device=self.device)
+            teacher_embeddings = torch.from_numpy(teacher_embeddings).to(self.device)
         else:
-            teacher_embeddings = teacher_embeddings.to(dtype=target_dtype, device=self.device)
-        student_embeddings = student_embeddings.to(dtype=target_dtype, device=self.device)
+            teacher_embeddings = teacher_embeddings.to(self.device)
+        student_embeddings = student_embeddings.to(self.device)

         teacher_embeddings = F.normalize(teacher_embeddings, p=2, dim=1)
         student_embeddings = F.normalize(student_embeddings, p=2, dim=1)
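The loss body sits outside this hunk, but the L2 normalization above is the standard prelude to a cosine-similarity distillation objective. A minimal sketch of what such a loss typically looks like (assumes `(batch, dim)` inputs; not necessarily the exact loss used in this repo):

```python
import torch
import torch.nn.functional as F

def cosine_distill_loss(student: torch.Tensor, teacher: torch.Tensor) -> torch.Tensor:
    student = F.normalize(student, p=2, dim=1)
    teacher = F.normalize(teacher, p=2, dim=1)
    # After L2 normalization the row-wise dot product is cosine similarity.
    return (1.0 - (student * teacher).sum(dim=1)).mean()
```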
@@ -513,32 +485,12 @@ def train_step(self, batch: Dict) -> Dict:
513485 """
514486 Single training step on a batch.
515487
516- if torch.cuda.is_available() and str(self.device) == 'cuda':
517- self.model.to(self.device)
518- self._cast_nonbatchnorm_to_dtype(torch.bfloat16)
519- self._cast_batchnorm_to_float32()
520- tensor_dtype = torch.bfloat16
521- elif torch.backends.mps.is_available() and str(self.device) == 'mps':
522- self.model.to(self.device)
523- self._cast_nonbatchnorm_to_dtype(torch.bfloat16)
524- self._cast_batchnorm_to_float32()
525- tensor_dtype = torch.bfloat16
526- else:
527- self.model.to(self.device)
528- self._cast_nonbatchnorm_to_dtype(torch.float32)
529- self._cast_batchnorm_to_float32()
530- tensor_dtype = torch.float32
488+ # Always use default float32 for training
489+ self.model.to(self.device)
531490 self.model.train()
532491 step_metrics: Dictionary with loss and performance metrics
533492 """
534493
535- # Only patch CUDA: force bfloat16, otherwise use autodetected self.dtype (float16 for MPS, float32 for CPU)
536- if torch .cuda .is_available () and str (self .device ) == 'cuda' :
537- self .model .to (self .device , dtype = torch .bfloat16 )
538- tensor_dtype = torch .bfloat16
539- else :
540- self .model .to (self .device , dtype = self .dtype )
541- tensor_dtype = self .dtype
494+ self .model .to (self .device )
542495 self .model .train ()
543496
544497 if self .accumulation_counter == 0 :
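The `accumulation_counter` check that closes this hunk is the usual gradient-accumulation guard: zero gradients only at the start of a window, then step once per window. A self-contained skeleton of the pattern (names are illustrative, not repository code):

```python
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
accumulation_steps = 4

for step in range(8):
    if step % accumulation_steps == 0:
        optimizer.zero_grad()
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    # Scale so the accumulated gradient matches one large-batch step.
    loss = torch.nn.functional.mse_loss(model(x), y) / accumulation_steps
    loss.backward()  # gradients accumulate across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
```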
@@ -553,20 +506,19 @@ def train_step(self, batch: Dict) -> Dict:
             batch.get('teacher_segment_embeddings', [None] * len(batch['audio_segments']))
         )):

-            # Only use bfloat16 for CUDA, float32 everywhere else
-            # Move mel_segments to correct device/dtype
+            # Always use default float32 for all input tensors
             if not isinstance(mel_segments, torch.Tensor):
                 mel_segments = torch.from_numpy(mel_segments)
-            mel_segments = mel_segments.to(device=self.device, dtype=tensor_dtype)
+            mel_segments = mel_segments.to(self.device)

             # Move teacher_emb and teacher_segment_embs to correct device/dtype if tensor
             if isinstance(teacher_emb, np.ndarray):
                 teacher_emb = torch.from_numpy(teacher_emb)
             if isinstance(teacher_emb, torch.Tensor):
-                teacher_emb = teacher_emb.to(device=self.device, dtype=tensor_dtype)
+                teacher_emb = teacher_emb.to(self.device)
             if teacher_segment_embs is not None:
                 teacher_segment_embs = [torch.from_numpy(e) if isinstance(e, np.ndarray) else e for e in teacher_segment_embs]
-                teacher_segment_embs = [e.to(device=self.device, dtype=tensor_dtype) if isinstance(e, torch.Tensor) else e for e in teacher_segment_embs]
+                teacher_segment_embs = [e.to(self.device) if isinstance(e, torch.Tensor) else e for e in teacher_segment_embs]

             if mel_segments.shape[0] < 2:
                 logger.warning(f"⚠️ Skipping song {batch['song_ids'][i]} - only {mel_segments.shape[0]} segment (BatchNorm needs ≥2)")
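The `shape[0] < 2` guard exists because BatchNorm in training mode cannot compute batch statistics from a single sample; PyTorch raises a `ValueError`. Quick standalone demonstration:

```python
import torch

bn = torch.nn.BatchNorm1d(8).train()
bn(torch.randn(2, 8))      # fine: batch statistics over 2 samples
try:
    bn(torch.randn(1, 8))
except ValueError as e:
    print(e)  # "Expected more than 1 value per channel when training..."
```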
@@ -580,7 +532,7 @@ def train_step(self, batch: Dict) -> Dict:
                 chunk_end = min(chunk_start + chunk_size, mel_segments.shape[0])
                 chunk = mel_segments[chunk_start:chunk_end]
                 # Ensure chunk is on correct device/dtype
-                chunk = chunk.to(device=self.device, dtype=tensor_dtype)
+                chunk = chunk.to(self.device)
                 chunk_embeddings = self.model.forward(chunk)
                 segment_embeddings_list.append(chunk_embeddings)

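This loop (repeated in the next two hunks) is a plain chunked forward pass: segments go through the model in fixed-size slices so peak memory stays bounded, and the per-chunk embeddings are concatenated afterwards. A standalone sketch of the pattern:

```python
import torch

model = torch.nn.Linear(16, 4)      # stand-in for the audio encoder
mel_segments = torch.randn(10, 16)
chunk_size = 4

chunks = []
for start in range(0, mel_segments.shape[0], chunk_size):
    chunks.append(model(mel_segments[start:start + chunk_size]))
segment_embeddings = torch.cat(chunks, dim=0)  # shape (10, 4)
```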
@@ -600,7 +552,7 @@ def train_step(self, batch: Dict) -> Dict:
             for chunk_start in range(0, mel_segments.shape[0], chunk_size):
                 chunk_end = min(chunk_start + chunk_size, mel_segments.shape[0])
                 chunk = mel_segments[chunk_start:chunk_end]
-                chunk = chunk.to(device=self.device, dtype=tensor_dtype)
+                chunk = chunk.to(self.device)
                 chunk_embeddings = self.model.forward(chunk)
                 segment_embeddings_list.append(chunk_embeddings)

@@ -618,7 +570,7 @@ def train_step(self, batch: Dict) -> Dict:
             for chunk_start in range(0, mel_segments.shape[0], chunk_size):
                 chunk_end = min(chunk_start + chunk_size, mel_segments.shape[0])
                 chunk = mel_segments[chunk_start:chunk_end]
-                chunk = chunk.to(device=self.device, dtype=tensor_dtype)
+                chunk = chunk.to(self.device)
                 chunk_embeddings = self.model.forward(chunk)
                 segment_embeddings_list.append(chunk_embeddings)

@@ -660,10 +612,10 @@ def train_step(self, batch: Dict) -> Dict:
             }

         # Concatenate and ensure all embeddings are on correct device/dtype
-        student_embeddings = torch.cat(student_embeddings, dim=0).to(device=self.device, dtype=tensor_dtype)
+        student_embeddings = torch.cat(student_embeddings, dim=0).to(self.device)
         teacher_embeddings = [torch.from_numpy(e) if isinstance(e, np.ndarray) else e for e in teacher_embeddings]
-        teacher_embeddings = [e.to(device=self.device, dtype=tensor_dtype) if isinstance(e, torch.Tensor) else e for e in teacher_embeddings]
-        teacher_embeddings = torch.cat([e.unsqueeze(0) if e.dim() == 1 else e for e in teacher_embeddings], dim=0).to(device=self.device, dtype=tensor_dtype)
+        teacher_embeddings = [e.to(self.device) if isinstance(e, torch.Tensor) else e for e in teacher_embeddings]
+        teacher_embeddings = torch.cat([e.unsqueeze(0) if e.dim() == 1 else e for e in teacher_embeddings], dim=0).to(self.device)

         loss, loss_dict = self.compute_loss(student_embeddings, teacher_embeddings)

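A note on the `unsqueeze(0)` in the last hunk: teacher embeddings may arrive as 1-D `(dim,)` vectors or 2-D `(n, dim)` batches, and `torch.cat` needs a uniform rank. Standalone illustration:

```python
import torch

embs = [torch.randn(512), torch.randn(3, 512)]  # mixed 1-D and 2-D
stacked = torch.cat([e.unsqueeze(0) if e.dim() == 1 else e for e in embs], dim=0)
print(stacked.shape)  # torch.Size([4, 512])
```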