Skip to content

Commit e3eb146

Browse files
authored
[Model Runner V2] Add ModelStateInterface [4/N] (vllm-project#35621)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
1 parent 95a395d commit e3eb146

File tree

5 files changed

+90
-4
lines changed

5 files changed

+90
-4
lines changed

vllm/v1/worker/gpu/cudagraph_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from vllm.v1.worker.gpu.block_table import BlockTables
2323
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
2424
from vllm.v1.worker.gpu.input_batch import InputBuffers
25-
from vllm.v1.worker.gpu.model_states import ModelState
25+
from vllm.v1.worker.gpu.model_states.interface import ModelState
2626
from vllm.v1.worker.utils import AttentionGroup
2727

2828

vllm/v1/worker/gpu/model_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
)
7979
from vllm.v1.worker.gpu.lora_utils import LoraState
8080
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
81-
from vllm.v1.worker.gpu.model_states import ModelState
81+
from vllm.v1.worker.gpu.model_states import init_model_state
8282
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
8383
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
8484
from vllm.v1.worker.gpu.sample.output import SamplerOutput
@@ -267,7 +267,7 @@ def load_model(self, *args, **kwargs) -> None:
267267
prepare_communication_buffer_for_model(self.speculator)
268268

269269
# Initialize the components that require the model.
270-
self.model_state = ModelState(
270+
self.model_state = init_model_state(
271271
self.vllm_config, self.model, self.encoder_cache, self.device
272272
)
273273
if self.is_pooling_model:
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import torch
4+
import torch.nn as nn
5+
6+
from vllm.config import VllmConfig
7+
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
8+
9+
10+
def init_model_state(
    vllm_config: VllmConfig,
    model: nn.Module,
    encoder_cache: EncoderCache | None,
    device: torch.device,
):
    """Build and return the ModelState implementation for this worker.

    Factory indirection so callers (e.g. the model runner) construct a
    model state without naming a concrete class; currently this always
    returns ``DefaultModelState``.

    Args:
        vllm_config: Global vLLM configuration object.
        model: The loaded model whose per-request state is managed.
        encoder_cache: Cache for encoder outputs, or ``None`` when the
            model takes no encoder (multimodal) inputs — TODO confirm
            the ``None`` semantics against the caller.
        device: Torch device the model runs on.

    Returns:
        A ``DefaultModelState`` instance (a ``ModelState`` subclass).
    """
    # NOTE(review): imported lazily inside the function — presumably to
    # avoid an import cycle, since the concrete implementation imports
    # the ModelState interface from this package.
    from vllm.v1.worker.gpu.model_states.default import DefaultModelState

    return DefaultModelState(vllm_config, model, encoder_cache, device)

vllm/v1/worker/gpu/model_states.py renamed to vllm/v1/worker/gpu/model_states/default.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
1414
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
1515
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
16+
from vllm.v1.worker.gpu.model_states.interface import ModelState
1617
from vllm.v1.worker.gpu.states import RequestState
1718
from vllm.v1.worker.utils import AttentionGroup
1819

1920

20-
class ModelState:
21+
class DefaultModelState(ModelState):
2122
def __init__(
2223
self,
2324
vllm_config: VllmConfig,
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from abc import ABC, abstractmethod
4+
from typing import Any
5+
6+
import torch
7+
import torch.nn as nn
8+
9+
from vllm.config import VllmConfig
10+
from vllm.v1.core.sched.output import NewRequestData
11+
from vllm.v1.kv_cache_interface import KVCacheConfig
12+
from vllm.v1.worker.gpu.input_batch import InputBatch
13+
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
14+
from vllm.v1.worker.gpu.states import RequestState
15+
from vllm.v1.worker.utils import AttentionGroup
16+
17+
18+
class ModelState(ABC):
    """Abstract interface for model-specific state in the GPU model runner.

    Concrete implementations (e.g. ``DefaultModelState``) own the
    model-side bookkeeping needed to turn scheduler output into model
    inputs: tracking requests, preparing (dummy) input tensors,
    gathering multimodal embeddings, and building per-attention-group
    metadata. All methods are abstract; subclasses must implement every
    one.
    """

    @abstractmethod
    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ) -> None:
        """Initialize state for ``model`` on ``device``.

        ``encoder_cache`` may be ``None`` — presumably for models with
        no encoder/multimodal inputs; verify against the factory caller.
        """
        raise NotImplementedError

    @abstractmethod
    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        """Register a newly scheduled request at batch slot ``req_index``."""
        raise NotImplementedError

    @abstractmethod
    def apply_staged_writes(self) -> None:
        """Commit any buffered/staged state updates.

        NOTE(review): exact staging semantics are defined by the
        implementation — confirm against ``DefaultModelState``.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> torch.Tensor:
        """Return multimodal embeddings for the scheduled encoder inputs.

        ``scheduled_encoder_inputs`` maps request id to the indices of
        the encoder inputs scheduled this step — TODO confirm key/value
        meaning against the scheduler output.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, torch.Tensor | None]:
        """Build the model's keyword inputs for the current batch.

        Returns a mapping of input name to tensor (or ``None`` when an
        input is absent for this step).
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_dummy_inputs(
        self, num_reqs: int, num_tokens: int
    ) -> dict[str, torch.Tensor | None]:
        """Build placeholder inputs of the given batch/token sizes.

        Used for warmup/profiling-style runs where no real requests
        exist — presumably CUDA graph capture; confirm with callers.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_attn(
        self,
        input_batch: InputBatch,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
    ) -> dict[str, Any]:
        """Build attention metadata for the batch.

        Returns a mapping whose values are implementation-defined
        (hence ``Any``) — per attention group/layer metadata consumed
        by the attention backends.
        """
        raise NotImplementedError

0 commit comments

Comments (0)