Apple Metal (MPS) GPU Optimization

This document describes the optimizations made to enable fast music generation on Apple Silicon (M1/M2/M3) devices using Metal Performance Shaders (MPS).

Problem

PR #11 fixed a blocker, but generation was still running very slowly on Apple Silicon, most likely because it was falling back to the CPU instead of utilizing the GPU.

Root Cause

The code was loading models in torch.float32 precision on MPS devices. MPS supports float32, but those operations are significantly slower than float16: MPS is optimized for half-precision (float16) operations, which leverage the GPU's native capabilities.

Solution

1. Float16 Precision (Critical Performance Fix)

Changed model dtype from torch.float32 to torch.float16 for both HeartMuLa and HeartCodec models when running on MPS devices.

Why this matters:

  • MPS has native hardware acceleration for float16 operations
  • float32 operations on MPS may fall back to slower execution paths
  • float16 on MPS is typically 2-4x faster than float32
  • Memory usage is also reduced by half
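
As a rough illustration, the dtype selection boils down to something like the sketch below; the loader functions are placeholders, not the project's actual API.

    import torch

    # Minimal sketch: float16 on MPS, float32 elsewhere. `load_heartmula` /
    # `load_heartcodec` stand in for the project's real loading functions.
    if torch.backends.mps.is_available():
        device, dtype = torch.device("mps"), torch.float16   # native fp16 path
    else:
        device, dtype = torch.device("cpu"), torch.float32   # keep full precision off-GPU

    mula = load_heartmula(torch_dtype=dtype).to(device)
    codec = load_heartcodec(torch_dtype=dtype).to(device)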

2. Explicit Device Management

Added verification and automatic correction for model device placement:

  • Verify models are loaded on MPS after initialization
  • Automatically move models to MPS if they end up on the wrong device
  • Explicitly set pipeline device and dtype attributes
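
A minimal sketch of that verification step, assuming the pipeline exposes its models as attributes (mula and codec are illustrative names; the *_device and *_dtype attributes are the ones listed under Technical Details below):

    import torch

    def ensure_on_mps(model, name):
        """Move a model to MPS if it ended up on the wrong device (sketch)."""
        current = next(model.parameters()).device
        if current.type != "mps":
            print(f"[Apple Metal] {name} was on {current}, moving to mps")
            model.to("mps")
        return model

    pipeline.mula = ensure_on_mps(pipeline.mula, "HeartMuLa")    # attribute names illustrative
    pipeline.codec = ensure_on_mps(pipeline.codec, "HeartCodec")
    pipeline.mula_device = pipeline.codec_device = torch.device("mps")
    pipeline.mula_dtype = pipeline.codec_dtype = torch.float16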

3. MPS Fallback Configuration

Set PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable to enable graceful CPU fallback for any operations not yet supported by MPS, preventing crashes while maintaining GPU acceleration for supported operations.
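
This is a one-time setup at module import, roughly as sketched below; the flag needs to be in the environment before torch initializes the MPS backend, which is why it sits at the very top of the file.

    import os

    # Enable CPU fallback for ops without an MPS kernel; set before importing torch.
    os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

    import torch  # noqa: E402  (imported after the fallback flag on purpose)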

4. Consistent Dtype Handling

Ensured that lazy-loaded models (like HeartCodec) use the same dtype as the pipeline configuration instead of hardcoded values.
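
Sketched below, assuming a lazy accessor on the pipeline; the loader call and the codec attribute are placeholders, while codec_dtype and codec_device are the attributes set at load time.

    def get_codec(pipeline):
        """Lazily load HeartCodec with the pipeline's configured dtype instead of
        a hardcoded torch.float32 (sketch; `load_heartcodec` is a placeholder)."""
        if pipeline.codec is None:
            codec = load_heartcodec(torch_dtype=pipeline.codec_dtype)
            pipeline.codec = codec.to(pipeline.codec_device)
            if pipeline.codec_device.type == "mps":
                print(f"[Apple Metal] HeartCodec lazily loaded on MPS (dtype: {pipeline.codec_dtype})")
        return pipeline.codec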

Technical Details

Changes Made

  1. backend/app/services/music_service.py (top of file):

    • Added MPS configuration at module import time
    • Set PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable
  2. Model Loading (MPS pipeline initialization):

    • Changed from torch.float32 to torch.float16 for MPS
    • Added device verification after model loading
    • Explicitly set pipeline attributes: mula_device, codec_device, mula_dtype, codec_dtype
    • Added automatic device correction if models are on the wrong device
  3. Lazy Codec Loading (codec loading function):

    • Use pipeline.codec_dtype instead of hardcoded torch.float32
    • Added MPS-specific logging
  4. Generation Logging (generation start):

    • Added diagnostic logging to show device and dtype at generation start
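
The diagnostic line itself is a simple log call along these lines (sketch; mula_device and mula_dtype are the pipeline attributes set during loading):

    import logging

    logger = logging.getLogger(__name__)
    logger.info(
        "[Generation] Starting generation on device: %s (dtype: %s)",
        pipeline.mula_device, pipeline.mula_dtype,
    )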

Performance Impact

Expected performance improvements on Apple Silicon:

  • 2-4x faster generation compared to float32
  • Reduced memory usage (float16 uses half the memory of float32)
  • Full GPU utilization instead of CPU fallback

Testing

To verify the optimizations are working:

  1. Check the logs during model loading - you should see:

    [Apple Metal] Loading HeartMuLa and HeartCodec on MPS (generation + decode)
    [Apple Metal] HeartMuLa model device: mps:0
    [Apple Metal] HeartCodec model device: mps:0
    [Apple Metal] MPS pipeline: HeartMuLa and HeartCodec on MPS
    
  2. During generation, you should see:

    [Generation] Starting generation on device: mps:0 (dtype: torch.float16)
    
  3. Monitor Activity Monitor → GPU History - you should see GPU utilization during generation

MPS Compatibility Notes

  • Supported Operations: Most PyTorch operations work well on MPS
  • Float16 vs Float32: MPS strongly prefers float16 for performance
  • Bfloat16: Not supported on MPS, use float16 instead
  • Quantization: 4-bit quantization (BitsAndBytes) is CUDA-only, not available on MPS
  • Torch.compile: Not yet optimized for MPS, disabled for Apple Silicon
  • Unified Memory: MPS uses unified memory architecture, no explicit VRAM limits

Decode on MPS: limitation and fix

Limitation: HeartCodec’s decoder (heartlib) uses torch.jit.script for the snake() activation in ScalarModel (heartlib/heartcodec/models/sq_codec.py). JIT has limited MPS support and can raise invalid type: 'torch.mps.FloatTensor' errors or silently fall back to the CPU when decode runs on Metal, so the “Converting Audio” step stays slow.

Fix (baseline extension): We extend our baseline for full MPS support by patching heartlib at runtime when MPS is available (same idea as SpeechBrain PR #1805 / issue #1794):

  1. snake() – In music_service.py we replace sq_codec.snake with an eager (non-JIT) implementation so the decode graph runs natively on MPS.
  2. .type(tensor.type()) – heartlib’s PixArtAlphaCombinedFlowEmbeddings.timestep_embedding uses .type(timesteps.type()), which on MPS can raise invalid type: 'torch.mps.FloatTensor'. We patch it to use .to(device=..., dtype=...) instead (match device and dtype explicitly, like the SpeechBrain fix).
  • Patches run before importing HeartCodec so all decode code paths use the fixed versions.
  • No fork of heartlib is required; patches are applied once at module load when torch.backends.mps.is_available().
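
A rough sketch of the patching approach is shown below. It assumes snake() takes the standard Snake-activation signature (x, alpha) and that the module path matches the description above; the actual heartlib internals may differ.

    import torch

    def _snake_eager(x, alpha):
        # Eager (non-JIT) Snake activation: x + (1/alpha) * sin(alpha * x)^2.
        # Signature assumed from the standard Snake definition.
        return x + (1.0 / alpha) * torch.sin(alpha * x) ** 2

    def apply_mps_patches():
        if not torch.backends.mps.is_available():
            return
        # 1. Replace the torch.jit.script'ed snake() with the eager version so the
        #    decode graph runs natively on MPS.
        from heartlib.heartcodec.models import sq_codec
        sq_codec.snake = _snake_eager
        # 2. In timestep_embedding, replace
        #        emb = emb.type(timesteps.type())                       # rejected on MPS
        #    with
        #        emb = emb.to(device=timesteps.device, dtype=timesteps.dtype)
        #    (the concrete monkey-patch is omitted here for brevity).

    # Run before importing HeartCodec so every decode path sees the patched versions.
    apply_mps_patches()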

Port behavior (unchanged):

  1. Load: HeartMuLa and HeartCodec both on MPS with float16.
  2. Decode: detokenize(frames_for_codec) runs on MPS with the patched snake. If decode still raises (e.g. another op rejects MPS), we fall back to CPU for that call, then restore codec to MPS.
  3. Save (same as CPU): We always convert the waveform to CPU float32 and round-trip via numpy before torchaudio.save(), so the backend never sees an MPS tensor.
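
A rough sketch of this decode-and-save flow; apart from detokenize() and torchaudio.save(), the attribute and method names are illustrative.

    import torch
    import torchaudio

    def decode_and_save(pipeline, frames_for_codec, out_path, sample_rate):
        """Decode on MPS, falling back to CPU only if an op rejects MPS (sketch)."""
        try:
            wav = pipeline.codec.detokenize(frames_for_codec)
        except RuntimeError:
            pipeline.codec.to("cpu")                                    # fall back for this call
            wav = pipeline.codec.detokenize(frames_for_codec.to("cpu"))
            pipeline.codec.to("mps")                                    # restore codec to MPS
        # Save path is identical to CPU: hand torchaudio a CPU float32 tensor,
        # round-tripped through numpy, so it never sees an MPS tensor.
        wav_np = wav.detach().to("cpu", torch.float32).numpy()
        torchaudio.save(out_path, torch.from_numpy(wav_np), sample_rate)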

Result: generation and decoding run on Apple Metal (M4 etc.); only the final save step uses CPU tensors for compatibility.

Future Optimizations

Potential areas for further optimization:

  1. Profile specific operations to identify any remaining CPU fallbacks
  2. Consider using Metal Performance Shaders directly for certain operations
  3. Explore torch.compile support as it matures for MPS
  4. Investigate mixed precision training/inference techniques

References