
Commit 033a8a4

Copilot and lmangani committed
Fix MPS FloatTensor error by preserving torch.long dtype during frame stacking
Co-authored-by: lmangani <1423657+lmangani@users.noreply.github.com>
1 parent 1431967 commit 033a8a4

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

backend/app/services/music_service.py

@@ -804,7 +804,9 @@ def _pad_audio_token(token):
                 progress = int((i + 1) / max_audio_frames * 100)
                 callback(progress, f"Generating audio... {i + 1}/{max_audio_frames} frames")

-        frames = torch.stack(frames).permute(1, 2, 0).squeeze(0).cpu()  # Move to CPU immediately
+        # Stack frames and explicitly preserve torch.long dtype (critical for MPS compatibility)
+        # torch.stack may promote dtype to float on MPS, so we explicitly convert to long before CPU
+        frames = torch.stack(frames).permute(1, 2, 0).squeeze(0).to(dtype=torch.long).cpu()

         # Sequential offload: Move HeartMuLa to CPU before loading HeartCodec
         # This allows fitting on smaller GPUs (12GB) by never having both models in VRAM
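The fix can be sketched in isolation. The snippet below is a minimal, hypothetical reproduction (the frame shapes and vocabulary size are assumptions, not taken from the repository): it builds fake integer codec-token frames, stacks and permutes them as the patched line does, and forces `torch.long` before the `.cpu()` transfer so integer token codes cannot arrive as floats.

```python
import torch

# Hypothetical stand-ins for per-step codec token frames, shape (batch, codebooks).
# Real frames come from the HeartMuLa generation loop; these are random tokens.
frames = [torch.randint(0, 1024, (1, 8)) for _ in range(4)]

# Mirror the patched line: stack along a new step axis, move that axis last,
# drop the batch dim, then pin the dtype to long BEFORE moving to CPU.
# On MPS, torch.stack has been observed to promote to float, so the explicit
# .to(dtype=torch.long) is the safeguard.
stacked = torch.stack(frames).permute(1, 2, 0).squeeze(0).to(dtype=torch.long).cpu()

print(stacked.dtype)   # torch.long
print(stacked.shape)   # torch.Size([8, 4]) -> (codebooks, steps)
```

The key design point is ordering: the dtype conversion happens while the tensor is still on the generation device, so the CPU copy is created with the correct integer dtype rather than converted after a lossy float transfer.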
