You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This document describes the optimizations made to enable fast music generation on Apple Silicon (M1/M2/M3) devices using Metal Performance Shaders (MPS).
4
+
5
+
## Problem
6
+
7
+
PR #11 fixed a blocker but generation was running very slowly on Apple Silicon, most likely falling back to CPU instead of utilizing the GPU.
8
+
9
+
## Root Cause
10
+
11
+
The code was using `torch.float32` precision for models on MPS devices. While MPS supports float32, it is **significantly slower** than float16 operations. MPS is optimized for float16 (half-precision) operations which leverage the GPU's native capabilities.
Changed model dtype from `torch.float32` to `torch.float16` for both HeartMuLa and HeartCodec models when running on MPS devices.
18
+
19
+
**Why this matters:**
20
+
- MPS has native hardware acceleration for float16 operations
21
+
- float32 operations on MPS may fall back to slower execution paths
22
+
- float16 on MPS is typically **2-4x faster** than float32
23
+
- Memory usage is also reduced by half
24
+
25
+
### 2. Explicit Device Management
26
+
27
+
Added verification and automatic correction for model device placement:
28
+
- Verify models are loaded on MPS after initialization
29
+
- Automatically move models to MPS if they end up on wrong device
30
+
- Explicitly set pipeline device and dtype attributes
31
+
32
+
### 3. MPS Fallback Configuration
33
+
34
+
Set `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable to enable graceful CPU fallback for any operations not yet supported by MPS, preventing crashes while maintaining GPU acceleration for supported operations.
35
+
36
+
### 4. Consistent Dtype Handling
37
+
38
+
Ensured that lazy-loaded models (like HeartCodec) use the same dtype as the pipeline configuration instead of hardcoded values.
39
+
40
+
## Technical Details
41
+
42
+
### Changes Made
43
+
44
+
1.**`backend/app/services/music_service.py`** (top of file):
45
+
- Added MPS configuration at module import time
46
+
- Set `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable
0 commit comments