Skip to content

Commit 63e150b

Browse files
Copilot and lmangani committed
Address code review feedback: fix comments, add error handling, update docs
Co-authored-by: lmangani <1423657+lmangani@users.noreply.github.com>
1 parent a8ec451 commit 63e150b

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

MPS_OPTIMIZATION.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,21 @@ Ensured that lazy-loaded models (like HeartCodec) use the same dtype as the pipe
4141

4242
### Changes Made
4343

44-
1. **`backend/app/services/music_service.py`** (lines 19-27):
44+
1. **`backend/app/services/music_service.py`** (top of file):
4545
- Added MPS configuration at module import time
4646
- Set `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable
4747

48-
2. **Model Loading** (lines 1521-1565):
48+
2. **Model Loading** (MPS pipeline initialization):
4949
- Changed from `torch.float32` to `torch.float16` for MPS
5050
- Added device verification after model loading
5151
- Explicitly set pipeline attributes: `mula_device`, `codec_device`, `mula_dtype`, `codec_dtype`
5252
- Added automatic device correction if models are on wrong device
5353

54-
3. **Lazy Codec Loading** (lines 813-829):
54+
3. **Lazy Codec Loading** (codec loading function):
5555
- Use `pipeline.codec_dtype` instead of hardcoded `torch.float32`
5656
- Added MPS-specific logging
5757

58-
4. **Generation Logging** (line 731):
58+
4. **Generation Logging** (generation start):
5959
- Added diagnostic logging to show device and dtype at generation start
6060

6161
### Performance Impact

backend/app/services/music_service.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
from tokenizers import Tokenizer
1818

1919
# Configure MPS (Apple Metal) for optimal performance
20-
# These settings must be set before any PyTorch operations
20+
# Note: PYTORCH_ENABLE_MPS_FALLBACK can be set at runtime for fallback behavior
2121
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
2222
# Enable MPS fallback to CPU for unsupported operations (better than crashing)
23+
# This takes effect for subsequent tensor operations
2324
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
2425
# Note: MPS is already using Metal under the hood, no additional config needed
2526
logger_temp = logging.getLogger(__name__)
@@ -111,7 +112,7 @@ def get_autocast_context(device_type: str, dtype: torch.dtype):
111112
112113
PyTorch's autocast only supports 'cuda', 'cpu', and 'xpu' device types.
113114
For MPS (Apple Metal), autocast is not supported, so we use a nullcontext (no-op).
114-
Since MPS pipelines already use float32, no autocast is needed.
115+
MPS pipelines use float16 precision which is optimal for the hardware.
115116
116117
Args:
117118
device_type: Device type string ('cuda', 'cpu', 'mps', 'xpu')
@@ -121,7 +122,7 @@ def get_autocast_context(device_type: str, dtype: torch.dtype):
121122
Context manager for autocast or nullcontext for unsupported devices
122123
"""
123124
# torch.autocast doesn't support MPS device type
124-
# MPS pipelines already use float32, so autocast is not needed
125+
# MPS pipelines use float16 directly, which is optimal for the hardware
125126
if device_type == 'mps':
126127
return nullcontext()
127128

@@ -1537,7 +1538,6 @@ def patched_warmup(model, device_map, hf_quantizer):
15371538
os.environ["HF_HUB_DISABLE_CACHING_ALLOCATOR_WARMUP"] = "1"
15381539

15391540
try:
1540-
# MPS doesn't support bfloat16, use float32 instead
15411541
# IMPORTANT: For MPS, we need to use float16 (not float32) for optimal performance
15421542
# MPS has native support for float16 operations which are much faster than float32
15431543
print("[Apple Metal] Loading models with float16 precision for optimal MPS performance", flush=True)
@@ -1563,23 +1563,44 @@ def patched_warmup(model, device_map, hf_quantizer):
15631563
pipeline.mula_dtype = torch.float16
15641564
pipeline.codec_dtype = torch.float16
15651565

1566+
# Verify and correct model device placement
1567+
# Note: Accessing _mula and _codec (private attributes) is necessary here
1568+
# because the pipeline library doesn't provide public methods for device verification
15661569
if hasattr(pipeline, '_mula') and pipeline._mula is not None:
1567-
mula_device = next(pipeline._mula.parameters()).device
1568-
print(f"[Apple Metal] HeartMuLa model device: {mula_device}", flush=True)
1569-
if mula_device.type != 'mps':
1570-
logger.warning(f"[MPS] HeartMuLa model is on {mula_device}, not MPS! This will be slow.")
1571-
print(f"[Apple Metal] WARNING: HeartMuLa is on {mula_device}, moving to MPS...", flush=True)
1572-
pipeline._mula = pipeline._mula.to(mps_device)
1573-
print(f"[Apple Metal] HeartMuLa moved to MPS", flush=True)
1570+
try:
1571+
# Get first parameter's device, or handle case where model has no parameters
1572+
mula_params = list(pipeline._mula.parameters())
1573+
if mula_params:
1574+
mula_device = mula_params[0].device
1575+
print(f"[Apple Metal] HeartMuLa model device: {mula_device}", flush=True)
1576+
if mula_device.type != 'mps':
1577+
logger.warning(f"[MPS] HeartMuLa model is on {mula_device}, not MPS! This will be slow.")
1578+
print(f"[Apple Metal] WARNING: HeartMuLa is on {mula_device}, moving to MPS...", flush=True)
1579+
# Explicitly set both device and dtype for consistency
1580+
pipeline._mula = pipeline._mula.to(device=mps_device, dtype=torch.float16)
1581+
print(f"[Apple Metal] HeartMuLa moved to MPS with float16 precision", flush=True)
1582+
else:
1583+
logger.warning("[MPS] HeartMuLa model has no parameters - cannot verify device")
1584+
except Exception as e:
1585+
logger.warning(f"[MPS] Failed to verify HeartMuLa device: {e}")
15741586

15751587
if hasattr(pipeline, '_codec') and pipeline._codec is not None:
1576-
codec_device = next(pipeline._codec.parameters()).device
1577-
print(f"[Apple Metal] HeartCodec model device: {codec_device}", flush=True)
1578-
if codec_device.type != 'mps':
1579-
logger.warning(f"[MPS] HeartCodec model is on {codec_device}, not MPS! This will be slow.")
1580-
print(f"[Apple Metal] WARNING: HeartCodec is on {codec_device}, moving to MPS...", flush=True)
1581-
pipeline._codec = pipeline._codec.to(mps_device)
1582-
print(f"[Apple Metal] HeartCodec moved to MPS", flush=True)
1588+
try:
1589+
# Get first parameter's device, or handle case where model has no parameters
1590+
codec_params = list(pipeline._codec.parameters())
1591+
if codec_params:
1592+
codec_device = codec_params[0].device
1593+
print(f"[Apple Metal] HeartCodec model device: {codec_device}", flush=True)
1594+
if codec_device.type != 'mps':
1595+
logger.warning(f"[MPS] HeartCodec model is on {codec_device}, not MPS! This will be slow.")
1596+
print(f"[Apple Metal] WARNING: HeartCodec is on {codec_device}, moving to MPS...", flush=True)
1597+
# Explicitly set both device and dtype for consistency
1598+
pipeline._codec = pipeline._codec.to(device=mps_device, dtype=torch.float16)
1599+
print(f"[Apple Metal] HeartCodec moved to MPS with float16 precision", flush=True)
1600+
else:
1601+
logger.warning("[MPS] HeartCodec model has no parameters - cannot verify device")
1602+
except Exception as e:
1603+
logger.warning(f"[MPS] Failed to verify HeartCodec device: {e}")
15831604

15841605
print("[Apple Metal] MPS pipeline loaded successfully with float16 precision", flush=True)
15851606
print("[Apple Metal] All models are on MPS device for hardware acceleration", flush=True)

0 commit comments

Comments
 (0)