Skip to content

Commit 934bebf

Browse files
authored
Better errors for Transformers backend missing features (vllm-project#23759)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 885ca6d commit 934bebf

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

vllm/model_executor/models/transformers.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""Wrapper around `transformers` models"""
1818
from collections.abc import Iterable, Mapping
1919
from contextlib import contextmanager
20+
from pathlib import Path
2021
from typing import Literal, Optional, Union
2122

2223
import regex as re
@@ -60,6 +61,21 @@
6061
logger = init_logger(__name__)
6162

6263

64+
def get_feature_request_tip(
65+
model: str,
66+
trust_remote_code: bool,
67+
) -> str:
68+
hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new"
69+
gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
70+
url = hf_url if trust_remote_code else gh_url
71+
prefix = f"Please open {url} to request support for this feature. "
72+
if Path(model).exists():
73+
prefix = ""
74+
doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
75+
tip = f"See {doc_url} for instructions on how to add support yourself."
76+
return f"{prefix}{tip}"
77+
78+
6379
def vllm_flash_attention_forward(
6480
# Transformers args
6581
module: torch.nn.Module,
@@ -480,8 +496,11 @@ def pipeline_parallel(self):
480496
return
481497

482498
if not self.model.supports_pp_plan:
499+
tip = get_feature_request_tip(self.model_config.model,
500+
self.model_config.trust_remote_code)
483501
raise ValueError(
484-
f"{type(self.model)} does not support pipeline parallel yet!")
502+
f"{type(self.model)} does not support pipeline parallel. {tip}"
503+
)
485504

486505
module_lists = []
487506
module_list_idx = None
@@ -535,8 +554,10 @@ def tensor_parallel(self):
535554
models_with_tp_plan = filter(supports_tp_plan, pretrained_models)
536555

537556
if not any(models_with_tp_plan) and self.tp_size > 1:
557+
tip = get_feature_request_tip(self.model_config.model,
558+
self.model_config.trust_remote_code)
538559
raise ValueError(
539-
f"{type(self.model)} does not support tensor parallel yet!")
560+
f"{type(self.model)} does not support tensor parallel. {tip}")
540561

541562
def _tensor_parallel(module: nn.Module,
542563
prefix: str = "",

0 commit comments

Comments
 (0)