Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pytorch_lightning==2.5.1
soundfile==0.13.1
torch
torchaudio
torchcodec
torchvision
tqdm
transformers==4.50.0
Expand Down
73 changes: 61 additions & 12 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import matplotlib
import torch.nn.functional as F
import torch.utils.data
from enum import Enum
from pytorch_lightning.core import LightningModule
from torch.utils.data import DataLoader
from acestep.schedulers.scheduling_flow_match_euler_discrete import (
Expand All @@ -33,6 +34,22 @@
torch.set_float32_matmul_precision("high")


class DeviceType(Enum):
    """Compute backends this module can run on.

    Each member's value is the device string PyTorch uses for that
    backend (``"cuda"``, ``"mps"`` or ``"cpu"``).
    """

    CUDA = "cuda"
    MPS = "mps"
    CPU = "cpu"

    @staticmethod
    def detect() -> "DeviceType":
        """Return the backend available on the current machine.

        CUDA is probed first, then MPS; CPU is the fallback when neither
        reports itself as available. The MPS probe is only made when CUDA
        is not available.
        """
        if torch.cuda.is_available():
            selected = DeviceType.CUDA
        elif torch.backends.mps.is_available():
            selected = DeviceType.MPS
        else:
            selected = DeviceType.CPU
        return selected


class Pipeline(LightningModule):
def __init__(
self,
Expand Down Expand Up @@ -210,7 +227,7 @@ def infer_mert_ssl(self, target_wavs, wav_lengths):
chunk_idx = 0
for i in range(bsz):
audio_chunks = chunk_hidden_states[
chunk_idx : chunk_idx + num_chunks_per_audio[i]
chunk_idx: chunk_idx + num_chunks_per_audio[i]
]
audio_hidden = torch.cat(
audio_chunks, dim=0
Expand Down Expand Up @@ -287,7 +304,7 @@ def infer_mhubert_ssl(self, target_wavs, wav_lengths):
chunk_idx = 0
for i in range(bsz):
audio_chunks = chunk_hidden_states[
chunk_idx : chunk_idx + num_chunks_per_audio[i]
chunk_idx: chunk_idx + num_chunks_per_audio[i]
]
audio_hidden = torch.cat(
audio_chunks, dim=0
Expand Down Expand Up @@ -325,7 +342,15 @@ def preprocess(self, batch, train=True):
mert_ssl_hidden_states = None
mhubert_ssl_hidden_states = None
if train:
with torch.amp.autocast(device_type="cuda", dtype=dtype):
device_type = DeviceType.detect()
if device_type == DeviceType.CUDA:
with torch.amp.autocast(device_type=device_type.value, dtype=dtype):
mert_ssl_hidden_states = self.infer_mert_ssl(target_wavs, wav_lengths)
mhubert_ssl_hidden_states = self.infer_mhubert_ssl(
target_wavs, wav_lengths
)
else:
# MPS/CPU: no autocast needed
mert_ssl_hidden_states = self.infer_mert_ssl(target_wavs, wav_lengths)
mhubert_ssl_hidden_states = self.infer_mhubert_ssl(
target_wavs, wav_lengths
Expand Down Expand Up @@ -447,11 +472,16 @@ def train_dataloader(self):
train=True,
train_dataset_path=self.hparams.dataset_path,
)
# Detect device for platform-specific settings
device_type = DeviceType.detect()
use_pin_memory = device_type == DeviceType.CUDA # Only beneficial for CUDA
num_workers = self.hparams.num_workers if device_type == DeviceType.CUDA else 0 # MPS needs 0

return DataLoader(
self.train_dataset,
shuffle=True,
num_workers=self.hparams.num_workers,
pin_memory=True,
num_workers=num_workers,
pin_memory=use_pin_memory,
collate_fn=self.train_dataset.collate_fn,
)

Expand All @@ -474,7 +504,7 @@ def get_timestep(self, bsz, device):
mean=self.hparams.logit_mean,
std=self.hparams.logit_std,
size=(bsz,),
device="cpu",
device=DeviceType.CPU.value,
)
u = torch.nn.functional.sigmoid(u)
indices = (u * self.scheduler.config.num_train_timesteps).long()
Expand Down Expand Up @@ -779,8 +809,8 @@ def plot_step(self, batch, batch_idx):
if (
global_step % self.hparams.every_plot_step != 0
or self.local_rank != 0
or torch.distributed.get_rank() != 0
or torch.cuda.current_device() != 0
or (torch.distributed.is_initialized() and torch.distributed.get_rank() != 0)
or (torch.cuda.is_available() and torch.cuda.current_device() != 0)
):
return
results = self.predict_step(batch)
Expand Down Expand Up @@ -832,20 +862,38 @@ def main(args):
checkpoint_callback = ModelCheckpoint(
monitor=None,
every_n_train_steps=args.every_n_train_steps,
save_top_k=-1,
save_top_k=args.save_top_k,
)
# add datetime str to version
logger_callback = TensorBoardLogger(
version=datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + args.exp_name,
save_dir=args.logger_dir,
)

# Detect hardware
device_type = DeviceType.detect()

# Set defaults based on hardware
if device_type == DeviceType.CUDA:
accelerator = "gpu"
devices = args.devices
strategy = "ddp_find_unused_parameters_true"
elif device_type == DeviceType.MPS:
accelerator = "mps"
devices = 1
strategy = "auto"
else:
accelerator = "cpu"
devices = 1
strategy = "auto"

trainer = Trainer(
accelerator="gpu",
devices=args.devices,
accelerator=accelerator,
devices=devices,
num_nodes=args.num_nodes,
precision=args.precision,
accumulate_grad_batches=args.accumulate_grad_batches,
strategy="ddp_find_unused_parameters_true",
strategy=strategy,
max_epochs=args.epochs,
max_steps=args.max_steps,
log_every_n_steps=1,
Expand Down Expand Up @@ -886,5 +934,6 @@ def main(args):
args.add_argument("--every_plot_step", type=int, default=2000)
args.add_argument("--val_check_interval", type=int, default=None)
args.add_argument("--lora_config_path", type=str, default="config/zh_rap_lora_config.json")
args.add_argument("--save_top_k", type=int, default=-1)
args = args.parse_args()
main(args)