Skip to content

Commit 8287454

Browse files
committed
fix mps: fall back to CPU for token-embedding resize when the MPS backend raises NotImplementedError
1 parent 333e752 commit 8287454

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

optillm/inference.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -770,8 +770,28 @@ def __init__(self, model_config: ModelConfig, cache_manager, device_manager, mod
770770

771771
self.tokenizer = self.setup_tokenizer(self.tokenizer)
772772

773+
# Handle token embedding resize with MPS device compatibility
773774
if self.base_model.get_input_embeddings().num_embeddings != len(self.tokenizer):
774-
self.base_model.resize_token_embeddings(len(self.tokenizer))
775+
try:
776+
self.base_model.resize_token_embeddings(len(self.tokenizer))
777+
except NotImplementedError as e:
778+
if "MPS" in str(e) and "linalg_cholesky_ex" in str(e):
779+
logger.warning("MPS device doesn't support token embedding resize operation. "
780+
"Temporarily moving to CPU for resize operation.")
781+
# Get current device
782+
original_device = next(self.base_model.parameters()).device
783+
784+
# Move model to CPU for resize operation
785+
self.base_model = self.base_model.cpu()
786+
self.base_model.resize_token_embeddings(len(self.tokenizer))
787+
788+
# Move model back to original device
789+
if original_device.type != 'cpu':
790+
self.base_model = self.base_model.to(original_device)
791+
logger.info(f"Model moved back to {original_device}")
792+
else:
793+
# Re-raise if it's a different NotImplementedError
794+
raise
775795

776796
self.current_model = self.base_model
777797

0 commit comments

Comments
 (0)