@@ -924,10 +924,13 @@ def _load_model():
         logger.info("Flash Attention 2 is not installed - falling back to default attention")

    elif 'mps' in device:
-        # MPS supports FP16
-        model_kwargs["torch_dtype"] = torch.float16
-        # model_kwargs["torch_dtype"] = torch.float32
-        logger.info("Using MPS device with float16 precision")
+        # Special handling for Gemma models, which have NaN issues with float16 on MPS
+        if 'gemma' in model_id.lower():
+            model_kwargs["torch_dtype"] = torch.float32
+            logger.info("Using MPS device with float32 for Gemma model (float16 causes NaN)")
+        else:
+            model_kwargs["torch_dtype"] = torch.float16
+            logger.info("Using MPS device with float16 precision")
    else:
        # CPU can use FP16 if available
        if hasattr(torch.cpu, 'has_fp16') and torch.cpu.has_fp16:
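For reference, here is a minimal standalone sketch of the dtype-selection pattern this patch introduces. `pick_torch_dtype` is a hypothetical helper, not part of the patch; the Gemma/MPS float16 NaN behavior is as described in the diff comments above.

```python
import torch

def pick_torch_dtype(model_id: str, device: str) -> torch.dtype:
    """Hypothetical helper mirroring the patched _load_model() logic."""
    if 'mps' in device:
        # Gemma models are reported to produce NaNs in float16 on MPS,
        # so fall back to full float32 precision for them.
        if 'gemma' in model_id.lower():
            return torch.float32
        # Other models can keep the faster half-precision path.
        return torch.float16
    # Elsewhere, default to float32 and let callers override per device.
    return torch.float32

# Example usage:
# model_kwargs["torch_dtype"] = pick_torch_dtype("google/gemma-2b", "mps")
```

Factoring the check into a helper like this keeps the model/device special cases in one place, so adding another exception later only touches the helper rather than the loading code.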