Skip to content

Commit 34b57c9

Browse files
committed
Revert "fix mps"
This reverts commit 8287454.
1 parent 8287454 commit 34b57c9

File tree

1 file changed

+1
-21
lines changed

1 file changed

+1
-21
lines changed

optillm/inference.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -770,28 +770,8 @@ def __init__(self, model_config: ModelConfig, cache_manager, device_manager, mod
770 770
771 771           self.tokenizer = self.setup_tokenizer(self.tokenizer)
772 772
773     -         # Handle token embedding resize with MPS device compatibility
774 773           if self.base_model.get_input_embeddings().num_embeddings != len(self.tokenizer):
775     -             try:
776     -                 self.base_model.resize_token_embeddings(len(self.tokenizer))
777     -             except NotImplementedError as e:
778     -                 if "MPS" in str(e) and "linalg_cholesky_ex" in str(e):
779     -                     logger.warning("MPS device doesn't support token embedding resize operation. "
780     -                                    "Temporarily moving to CPU for resize operation.")
781     -                     # Get current device
782     -                     original_device = next(self.base_model.parameters()).device
783     -
784     -                     # Move model to CPU for resize operation
785     -                     self.base_model = self.base_model.cpu()
786     -                     self.base_model.resize_token_embeddings(len(self.tokenizer))
787     -
788     -                     # Move model back to original device
789     -                     if original_device.type != 'cpu':
790     -                         self.base_model = self.base_model.to(original_device)
791     -                         logger.info(f"Model moved back to {original_device}")
792     -                 else:
793     -                     # Re-raise if it's a different NotImplementedError
794     -                     raise
    774 +             self.base_model.resize_token_embeddings(len(self.tokenizer))
795 775
796 776           self.current_model = self.base_model
797 777
0 commit comments

Comments
 (0)