
Commit 2b3b809

Update inference.py
1 parent b32b303 commit 2b3b809

File tree: 1 file changed (+7, -4)


optillm/inference.py

Lines changed: 7 additions & 4 deletions
@@ -924,10 +924,13 @@ def _load_model():
                 logger.info("Flash Attention 2 is not installed - falling back to default attention")

         elif 'mps' in device:
-            # MPS supports FP16
-            model_kwargs["torch_dtype"] = torch.float16
-            # model_kwargs["torch_dtype"] = torch.float32
-            logger.info("Using MPS device with float16 precision")
+            # Special handling for Gemma models which have NaN issues with float16 on MPS
+            if 'gemma' in model_id.lower():
+                model_kwargs["torch_dtype"] = torch.float32
+                logger.info("Using MPS device with float32 for Gemma model (float16 causes NaN)")
+            else:
+                model_kwargs["torch_dtype"] = torch.float16
+                logger.info("Using MPS device with float16 precision")
         else:
             # CPU can use FP16 if available
             if hasattr(torch.cpu, 'has_fp16') and torch.cpu.has_fp16:
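For context, here is a minimal, self-contained sketch of how this dtype selection could feed into model loading. The helper name pick_torch_dtype, the example model id, and the from_pretrained call are illustrative assumptions for this sketch, not the repository's actual surrounding code:

import torch
from transformers import AutoModelForCausalLM

def pick_torch_dtype(model_id: str, device: str):
    # Mirror the commit's logic: on MPS, Gemma models get float32 because
    # float16 is reported to produce NaNs; other models keep float16.
    if 'mps' in device:
        if 'gemma' in model_id.lower():
            return torch.float32
        return torch.float16
    return None  # elsewhere the loader decides the dtype for CUDA/CPU

model_id = "google/gemma-2b"  # hypothetical example id
device = "mps"
model_kwargs = {}
dtype = pick_torch_dtype(model_id, device)
if dtype is not None:
    model_kwargs["torch_dtype"] = dtype

# model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs).to(device)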
