diff --git a/requirements.txt b/requirements.txt index 15d61cd8..fb780476 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ pytorch_lightning==2.5.1 soundfile==0.13.1 torch torchaudio +torchcodec torchvision tqdm transformers==4.50.0 diff --git a/trainer.py b/trainer.py index a3ed9cbc..96d1775f 100644 --- a/trainer.py +++ b/trainer.py @@ -8,6 +8,7 @@ import matplotlib import torch.nn.functional as F import torch.utils.data +from enum import Enum from pytorch_lightning.core import LightningModule from torch.utils.data import DataLoader from acestep.schedulers.scheduling_flow_match_euler_discrete import ( @@ -33,6 +34,22 @@ torch.set_float32_matmul_precision("high") +class DeviceType(Enum): + """ Device type constants for PyTorch operations """ + CUDA = "cuda" + MPS = "mps" + CPU = "cpu" + + @staticmethod + def detect() -> 'DeviceType': + """Detect available hardware platform""" + if torch.cuda.is_available(): + return DeviceType.CUDA + if torch.backends.mps.is_available(): + return DeviceType.MPS + return DeviceType.CPU + + class Pipeline(LightningModule): def __init__( self, @@ -210,7 +227,7 @@ def infer_mert_ssl(self, target_wavs, wav_lengths): chunk_idx = 0 for i in range(bsz): audio_chunks = chunk_hidden_states[ - chunk_idx : chunk_idx + num_chunks_per_audio[i] + chunk_idx: chunk_idx + num_chunks_per_audio[i] ] audio_hidden = torch.cat( audio_chunks, dim=0 @@ -287,7 +304,7 @@ def infer_mhubert_ssl(self, target_wavs, wav_lengths): chunk_idx = 0 for i in range(bsz): audio_chunks = chunk_hidden_states[ - chunk_idx : chunk_idx + num_chunks_per_audio[i] + chunk_idx: chunk_idx + num_chunks_per_audio[i] ] audio_hidden = torch.cat( audio_chunks, dim=0 @@ -325,7 +342,15 @@ def preprocess(self, batch, train=True): mert_ssl_hidden_states = None mhubert_ssl_hidden_states = None if train: - with torch.amp.autocast(device_type="cuda", dtype=dtype): + device_type = DeviceType.detect() + if device_type == DeviceType.CUDA: + with torch.amp.autocast(device_type=device_type.value, dtype=dtype): + mert_ssl_hidden_states = self.infer_mert_ssl(target_wavs, wav_lengths) + mhubert_ssl_hidden_states = self.infer_mhubert_ssl( + target_wavs, wav_lengths + ) + else: + # MPS/CPU: no autocast needed mert_ssl_hidden_states = self.infer_mert_ssl(target_wavs, wav_lengths) mhubert_ssl_hidden_states = self.infer_mhubert_ssl( target_wavs, wav_lengths @@ -447,11 +472,16 @@ def train_dataloader(self): train=True, train_dataset_path=self.hparams.dataset_path, ) + # Detect device for platform-specific settings + device_type = DeviceType.detect() + use_pin_memory = device_type == DeviceType.CUDA # Only beneficial for CUDA + num_workers = self.hparams.num_workers if device_type == DeviceType.CUDA else 0 # MPS needs 0 + return DataLoader( self.train_dataset, shuffle=True, - num_workers=self.hparams.num_workers, - pin_memory=True, + num_workers=num_workers, + pin_memory=use_pin_memory, collate_fn=self.train_dataset.collate_fn, ) @@ -474,7 +504,7 @@ def get_timestep(self, bsz, device): mean=self.hparams.logit_mean, std=self.hparams.logit_std, size=(bsz,), - device="cpu", + device=DeviceType.CPU.value, ) u = torch.nn.functional.sigmoid(u) indices = (u * self.scheduler.config.num_train_timesteps).long() @@ -779,8 +809,8 @@ def plot_step(self, batch, batch_idx): if ( global_step % self.hparams.every_plot_step != 0 or self.local_rank != 0 - or torch.distributed.get_rank() != 0 - or torch.cuda.current_device() != 0 + or (torch.distributed.is_initialized() and torch.distributed.get_rank() != 0) + or (torch.cuda.is_available() and torch.cuda.current_device() != 0) ): return results = self.predict_step(batch) @@ -832,20 +862,38 @@ def main(args): checkpoint_callback = ModelCheckpoint( monitor=None, every_n_train_steps=args.every_n_train_steps, - save_top_k=-1, + save_top_k=args.save_top_k, ) # add datetime str to version logger_callback = TensorBoardLogger( version=datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + args.exp_name, save_dir=args.logger_dir, ) + + # Detect hardware + device_type = DeviceType.detect() + + # Set defaults based on hardware + if device_type == DeviceType.CUDA: + accelerator = "gpu" + devices = args.devices + strategy = "ddp_find_unused_parameters_true" + elif device_type == DeviceType.MPS: + accelerator = "mps" + devices = 1 + strategy = "auto" + else: + accelerator = "cpu" + devices = 1 + strategy = "auto" + trainer = Trainer( - accelerator="gpu", - devices=args.devices, + accelerator=accelerator, + devices=devices, num_nodes=args.num_nodes, precision=args.precision, accumulate_grad_batches=args.accumulate_grad_batches, - strategy="ddp_find_unused_parameters_true", + strategy=strategy, max_epochs=args.epochs, max_steps=args.max_steps, log_every_n_steps=1, @@ -886,5 +934,6 @@ def main(args): args.add_argument("--every_plot_step", type=int, default=2000) args.add_argument("--val_check_interval", type=int, default=None) args.add_argument("--lora_config_path", type=str, default="config/zh_rap_lora_config.json") + args.add_argument("--save_top_k", type=int, default=-1) args = args.parse_args() main(args)