1717from tokenizers import Tokenizer
1818
1919# Configure MPS (Apple Metal) for optimal performance
20- # These settings must be set before any PyTorch operations
20+ # Note: PYTORCH_ENABLE_MPS_FALLBACK must be set before the MPS backend first runs an unsupported op; setting it here, ahead of any MPS work, enables CPU fallback
2121if hasattr (torch .backends , 'mps' ) and torch .backends .mps .is_available ():
2222 # Enable MPS fallback to CPU for unsupported operations (better than crashing)
23+ # This must be in the environment before any MPS operations execute for the fallback to take effect
2324 os .environ .setdefault ("PYTORCH_ENABLE_MPS_FALLBACK" , "1" )
2425 # Note: MPS is already using Metal under the hood, no additional config needed
2526 logger_temp = logging .getLogger (__name__ )
@@ -111,7 +112,7 @@ def get_autocast_context(device_type: str, dtype: torch.dtype):
111112
112113 PyTorch's autocast only supports 'cuda', 'cpu', and 'xpu' device types.
113114 For MPS (Apple Metal), autocast is not supported, so we use a nullcontext (no-op).
114- Since MPS pipelines already use float32, no autocast is needed .
115+ MPS pipelines use float16 precision, which is well supported by the hardware.
115116
116117 Args:
117118 device_type: Device type string ('cuda', 'cpu', 'mps', 'xpu')
@@ -121,7 +122,7 @@ def get_autocast_context(device_type: str, dtype: torch.dtype):
121122 Context manager for autocast or nullcontext for unsupported devices
122123 """
123124 # torch.autocast doesn't support MPS device type
124- # MPS pipelines already use float32, so autocast is not needed
125+ # MPS pipelines use float16 directly, which is optimal for the hardware
125126 if device_type == 'mps' :
126127 return nullcontext ()
127128
@@ -1537,7 +1538,6 @@ def patched_warmup(model, device_map, hf_quantizer):
15371538 os .environ ["HF_HUB_DISABLE_CACHING_ALLOCATOR_WARMUP" ] = "1"
15381539
15391540 try :
1540- # MPS doesn't support bfloat16, use float32 instead
15411541 # IMPORTANT: For MPS, we need to use float16 (not float32) for optimal performance
15421542 # MPS has native support for float16 operations which are much faster than float32
15431543 print ("[Apple Metal] Loading models with float16 precision for optimal MPS performance" , flush = True )
@@ -1563,23 +1563,44 @@ def patched_warmup(model, device_map, hf_quantizer):
15631563 pipeline .mula_dtype = torch .float16
15641564 pipeline .codec_dtype = torch .float16
15651565
1566+ # Verify and correct model device placement
1567+ # Note: Accessing _mula and _codec (private attributes) is necessary here
1568+ # because the pipeline library doesn't provide public methods for device verification
15661569 if hasattr (pipeline , '_mula' ) and pipeline ._mula is not None :
1567- mula_device = next (pipeline ._mula .parameters ()).device
1568- print (f"[Apple Metal] HeartMuLa model device: { mula_device } " , flush = True )
1569- if mula_device .type != 'mps' :
1570- logger .warning (f"[MPS] HeartMuLa model is on { mula_device } , not MPS! This will be slow." )
1571- print (f"[Apple Metal] WARNING: HeartMuLa is on { mula_device } , moving to MPS..." , flush = True )
1572- pipeline ._mula = pipeline ._mula .to (mps_device )
1573- print (f"[Apple Metal] HeartMuLa moved to MPS" , flush = True )
1570+ try :
1571+ # Get first parameter's device, or handle case where model has no parameters
1572+ mula_params = list (pipeline ._mula .parameters ())
1573+ if mula_params :
1574+ mula_device = mula_params [0 ].device
1575+ print (f"[Apple Metal] HeartMuLa model device: { mula_device } " , flush = True )
1576+ if mula_device .type != 'mps' :
1577+ logger .warning (f"[MPS] HeartMuLa model is on { mula_device } , not MPS! This will be slow." )
1578+ print (f"[Apple Metal] WARNING: HeartMuLa is on { mula_device } , moving to MPS..." , flush = True )
1579+ # Explicitly set both device and dtype for consistency
1580+ pipeline ._mula = pipeline ._mula .to (device = mps_device , dtype = torch .float16 )
1581+ print (f"[Apple Metal] HeartMuLa moved to MPS with float16 precision" , flush = True )
1582+ else :
1583+ logger .warning ("[MPS] HeartMuLa model has no parameters - cannot verify device" )
1584+ except Exception as e :
1585+ logger .warning (f"[MPS] Failed to verify HeartMuLa device: { e } " )
15741586
15751587 if hasattr (pipeline , '_codec' ) and pipeline ._codec is not None :
1576- codec_device = next (pipeline ._codec .parameters ()).device
1577- print (f"[Apple Metal] HeartCodec model device: { codec_device } " , flush = True )
1578- if codec_device .type != 'mps' :
1579- logger .warning (f"[MPS] HeartCodec model is on { codec_device } , not MPS! This will be slow." )
1580- print (f"[Apple Metal] WARNING: HeartCodec is on { codec_device } , moving to MPS..." , flush = True )
1581- pipeline ._codec = pipeline ._codec .to (mps_device )
1582- print (f"[Apple Metal] HeartCodec moved to MPS" , flush = True )
1588+ try :
1589+ # Get first parameter's device, or handle case where model has no parameters
1590+ codec_params = list (pipeline ._codec .parameters ())
1591+ if codec_params :
1592+ codec_device = codec_params [0 ].device
1593+ print (f"[Apple Metal] HeartCodec model device: { codec_device } " , flush = True )
1594+ if codec_device .type != 'mps' :
1595+ logger .warning (f"[MPS] HeartCodec model is on { codec_device } , not MPS! This will be slow." )
1596+ print (f"[Apple Metal] WARNING: HeartCodec is on { codec_device } , moving to MPS..." , flush = True )
1597+ # Explicitly set both device and dtype for consistency
1598+ pipeline ._codec = pipeline ._codec .to (device = mps_device , dtype = torch .float16 )
1599+ print (f"[Apple Metal] HeartCodec moved to MPS with float16 precision" , flush = True )
1600+ else :
1601+ logger .warning ("[MPS] HeartCodec model has no parameters - cannot verify device" )
1602+ except Exception as e :
1603+ logger .warning (f"[MPS] Failed to verify HeartCodec device: { e } " )
15831604
15841605 print ("[Apple Metal] MPS pipeline loaded successfully with float16 precision" , flush = True )
15851606 print ("[Apple Metal] All models are on MPS device for hardware acceleration" , flush = True )
0 commit comments