This document describes the optimizations made to enable fast music generation on Apple Silicon (M1/M2/M3) devices using Metal Performance Shaders (MPS).
PR #11 fixed a blocker, but generation still ran very slowly on Apple Silicon, most likely because it was falling back to the CPU instead of using the GPU.
The code was using torch.float32 precision for models on MPS devices. MPS supports float32, but it is optimized for float16 (half-precision), which leverages the GPU's native capabilities and is significantly faster.
Changed model dtype from torch.float32 to torch.float16 for both HeartMuLa and HeartCodec models when running on MPS devices.
Why this matters:
- MPS has native hardware acceleration for float16 operations
- float32 operations on MPS may fall back to slower execution paths
- float16 on MPS is typically 2-4x faster than float32
- Memory usage is also reduced by half
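The device and dtype choice described above can be sketched as follows (a minimal illustration; the actual helper names in music_service.py may differ):

```python
import torch

def pick_device_and_dtype():
    """Prefer MPS with float16 on Apple Silicon; fall back to CPU float32."""
    if torch.backends.mps.is_available():
        # float16 is the fast path on MPS and halves memory use.
        return torch.device("mps"), torch.float16
    return torch.device("cpu"), torch.float32
```

A model would then be loaded with `model.to(device, dtype=dtype)` so both placement and precision come from one decision point.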
Added verification and automatic correction for model device placement:
- Verify models are loaded on MPS after initialization
- Automatically move models to MPS if they end up on wrong device
- Explicitly set pipeline device and dtype attributes
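The verify-and-correct step can be sketched like this (a hypothetical helper, not the literal music_service.py code):

```python
import torch

def ensure_on_mps(model: torch.nn.Module) -> torch.nn.Module:
    """Verify a loaded model's device and move it to MPS if needed."""
    current = next(model.parameters()).device
    if torch.backends.mps.is_available() and current.type != "mps":
        # Model ended up on the wrong device; correct it.
        model = model.to("mps")
    return model
```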
Set PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable to enable graceful CPU fallback for any operations not yet supported by MPS, preventing crashes while maintaining GPU acceleration for supported operations.
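The fallback flag must be in the environment before torch is first imported in the process, which is why it is set at module import time; a typical module-top arrangement:

```python
import os

# Enable graceful CPU fallback for ops MPS does not implement yet.
# This must be set before the first `import torch` in the process.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch  # noqa: E402  -- intentionally imported after env setup
```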
Ensured that lazy-loaded models (like HeartCodec) use the same dtype as the pipeline configuration instead of hardcoded values.
- `backend/app/services/music_service.py` (top of file):
  - Added MPS configuration at module import time
  - Set the `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable
- Model Loading (MPS pipeline initialization):
  - Changed from `torch.float32` to `torch.float16` for MPS
  - Added device verification after model loading
  - Explicitly set pipeline attributes: `mula_device`, `codec_device`, `mula_dtype`, `codec_dtype`
  - Added automatic device correction if models are on the wrong device
- Lazy Codec Loading (codec loading function):
  - Use `pipeline.codec_dtype` instead of hardcoded `torch.float32`
  - Added MPS-specific logging
- Generation Logging (generation start):
  - Added diagnostic logging to show device and dtype at generation start
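The lazy-load change can be illustrated as follows (the attribute and factory names here are assumptions for the sketch):

```python
import torch

def load_codec(pipeline, build_codec):
    """Lazily build the codec using the pipeline's configured dtype.

    `build_codec` is a zero-arg factory. Previously the dtype was
    hardcoded to torch.float32 regardless of the pipeline config.
    """
    dtype = getattr(pipeline, "codec_dtype", torch.float32)
    device = getattr(pipeline, "codec_device", "cpu")
    codec = build_codec().to(device=device, dtype=dtype)
    if str(device).startswith("mps"):
        # MPS-specific diagnostic logging.
        print(f"[Apple Metal] HeartCodec loaded with dtype {dtype}")
    return codec
```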
Expected performance improvements on Apple Silicon:
- 2-4x faster generation compared to float32
- Reduced memory usage (float16 uses half the memory of float32)
- Full GPU utilization instead of CPU fallback
To verify the optimizations are working:
- Check the logs during model loading; you should see:

  ```
  [Apple Metal] Loading HeartMuLa and HeartCodec on MPS (generation + decode)
  [Apple Metal] HeartMuLa model device: mps:0
  [Apple Metal] HeartCodec model device: mps:0
  [Apple Metal] MPS pipeline: HeartMuLa and HeartCodec on MPS
  ```

- During generation, you should see:

  ```
  [Generation] Starting generation on device: mps:0 (dtype: torch.float16)
  ```

- Monitor Activity Monitor → GPU History; you should see GPU utilization during generation.
- Supported Operations: Most PyTorch operations work well on MPS
- Float16 vs Float32: MPS strongly prefers float16 for performance
- Bfloat16: Not supported on MPS, use float16 instead
- Quantization: 4-bit quantization (BitsAndBytes) is CUDA-only, not available on MPS
- Torch.compile: Not yet optimized for MPS, disabled for Apple Silicon
- Unified Memory: MPS uses unified memory architecture, no explicit VRAM limits
Limitation: HeartCodec's decoder (heartlib) uses `torch.jit.script` for the `snake()` activation in `ScalarModel` (heartlib/heartcodec/models/sq_codec.py). JIT has limited MPS support and can raise `invalid type: 'torch.mps.FloatTensor'` or silently fall back to CPU when decode runs on Metal, so the "Converting Audio" step stays slow.
Fix (baseline extension): We extend our baseline for full MPS support by patching heartlib at runtime when MPS is available (same idea as SpeechBrain PR #1805 / issue #1794):
- `snake()`: In music_service.py we replace `sq_codec.snake` with an eager (non-JIT) implementation so the decode graph runs natively on MPS.
- `.type(tensor.type())`: heartlib's `PixArtAlphaCombinedFlowEmbeddings.timestep_embedding` uses `.type(timesteps.type())`, which on MPS can raise `invalid type: 'torch.mps.FloatTensor'`. We patch it to use `.to(device=..., dtype=...)` instead (matching device and dtype explicitly, like the SpeechBrain fix).
- Patches run before importing HeartCodec so all decode code paths use the fixed versions.
- No fork of heartlib is required; patches are applied once at module load when `torch.backends.mps.is_available()` returns True.
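The two patches can be sketched as follows, assuming the standard Snake formulation x + sin²(αx)/α; the exact heartlib internals may differ:

```python
import torch

def snake_eager(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # Eager (non-JIT) Snake activation: x + sin^2(alpha * x) / alpha.
    # Avoids torch.jit.script, which has limited MPS support; the small
    # epsilon guards against division by zero.
    return x + torch.sin(alpha * x) ** 2 / (alpha + 1e-9)

def match_like(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # Replacement for `x.type(ref.type())`, which can raise
    # "invalid type: 'torch.mps.FloatTensor'" on MPS: set device and
    # dtype explicitly instead.
    return x.to(device=ref.device, dtype=ref.dtype)

# Applied once at module load, before HeartCodec is imported, e.g.:
#   sq_codec.snake = snake_eager
```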
Port behavior (unchanged):
- Load: HeartMuLa and HeartCodec both on MPS with float16.
- Decode: `detokenize(frames_for_codec)` runs on MPS with the patched `snake()`. If decode still raises (e.g. another op rejects MPS), we fall back to CPU for that call, then restore the codec to MPS.
- Save (same as CPU): We always convert the waveform to CPU float32 and round-trip via numpy before `torchaudio.save()`, so the backend never sees an MPS tensor.
Result: generation and decoding run on Apple Metal (M4 etc.); only the final save step uses CPU tensors for compatibility.
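A sketch of the decode-with-fallback and CPU-save steps (the codec API names are assumptions for illustration; the `torchaudio.save()` call itself is omitted):

```python
import torch

def detokenize_with_fallback(codec, frames):
    """Try decoding on MPS; on failure, decode this one call on CPU,
    then restore the codec to MPS."""
    try:
        return codec.detokenize(frames)
    except RuntimeError:
        codec.to("cpu")
        try:
            frames_cpu = frames.to("cpu") if torch.is_tensor(frames) else frames
            return codec.detokenize(frames_cpu)
        finally:
            if torch.backends.mps.is_available():
                codec.to("mps")

def to_saveable(wav: torch.Tensor) -> torch.Tensor:
    # Convert to CPU float32 and round-trip via numpy so the
    # subsequent torchaudio.save() never sees an MPS tensor.
    arr = wav.detach().to("cpu", dtype=torch.float32).numpy()
    return torch.from_numpy(arr)
```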
Potential areas for further optimization:
- Profile specific operations to identify any remaining CPU fallbacks
- Consider using Metal Performance Shaders directly for certain operations
- Explore torch.compile support as it matures for MPS
- Investigate mixed precision training/inference techniques