Skip to content

Commit 26d97dd

Browse files
joerundedbyoung18
authored andcommitted
[Hardware] Add processor inputs to platform validation (vllm-project#16680)
Signed-off-by: Joe Runde <[email protected]>
1 parent 64914a8 commit 26d97dd

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

vllm/platforms/interface.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import torch
1010

11-
from vllm.inputs import PromptType
11+
from vllm.inputs import ProcessorInputs, PromptType
1212
from vllm.logger import init_logger
1313

1414
if TYPE_CHECKING:
@@ -401,6 +401,7 @@ def validate_request(
401401
cls,
402402
prompt: PromptType,
403403
params: Union[SamplingParams, PoolingParams],
404+
processed_inputs: ProcessorInputs,
404405
) -> None:
405406
"""Raises if this request is unsupported on this platform"""
406407

vllm/platforms/tpu.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66

77
import vllm.envs as envs
8-
from vllm.inputs import PromptType
8+
from vllm.inputs import ProcessorInputs, PromptType
99
from vllm.logger import init_logger
1010
from vllm.sampling_params import SamplingParams, SamplingType
1111

@@ -150,6 +150,7 @@ def validate_request(
150150
cls,
151151
prompt: PromptType,
152152
params: Union[SamplingParams, PoolingParams],
153+
processed_inputs: ProcessorInputs,
153154
) -> None:
154155
"""Raises if this request is unsupported on this platform"""
155156
if isinstance(params, SamplingParams):

vllm/v1/engine/processor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,6 @@ def process_inputs(
202202

203203
# TODO(woosuk): Support pooling models.
204204
# TODO(woosuk): Support encoder-decoder models.
205-
206-
from vllm.platforms import current_platform
207-
current_platform.validate_request(
208-
prompt=prompt,
209-
params=params,
210-
)
211205
self._validate_lora(lora_request)
212206
self._validate_params(params)
213207
if priority != 0:
@@ -231,6 +225,12 @@ def process_inputs(
231225
prompt_adapter_request=prompt_adapter_request,
232226
return_mm_hashes=self.use_hash,
233227
)
228+
from vllm.platforms import current_platform
229+
current_platform.validate_request(
230+
prompt=prompt,
231+
params=params,
232+
processed_inputs=processed_inputs,
233+
)
234234
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
235235

236236
self._validate_model_inputs(processed_inputs, lora_request)

0 commit comments

Comments
 (0)