
Commit 27f4c2f

[Renderer] Separate out RendererConfig from ModelConfig (vllm-project#30145)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent: a49d813

105 files changed (+971 / -799 lines)

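The change repeated across the test files below is mechanical: prompt-rendering and tokenizer settings now come from a RendererConfig that wraps the ModelConfig, and VllmConfig is constructed with both. A minimal sketch of the new construction pattern, using only names visible in these diffs (the dtype value is illustrative):

    # Sketch of the construction pattern used throughout the test updates below.
    import torch

    from vllm.config import ModelConfig, RendererConfig, VllmConfig

    model_config = ModelConfig(dtype=torch.bfloat16)
    vllm_config = VllmConfig(
        model_config=model_config,
        renderer_config=RendererConfig(model_config=model_config),
    )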

docs/contributing/model/transcription.md

Lines changed: 6 additions & 6 deletions

@@ -22,7 +22,7 @@ Declare supported languages and capabilities:
 import torch
 from torch import nn
 
-from vllm.config import ModelConfig, SpeechToTextConfig
+from vllm.config import RendererConfig, SpeechToTextConfig
 from vllm.inputs.data import PromptType
 from vllm.model_executor.models.interfaces import SupportsTranscription
 
@@ -52,7 +52,7 @@ This is for controlling general behavior of the API when serving your model:
     @classmethod
     def get_speech_to_text_config(
         cls,
-        model_config: ModelConfig,
+        renderer_config: RendererConfig,
         task_type: Literal["transcribe", "translate"],
     ) -> SpeechToTextConfig:
         return SpeechToTextConfig(
@@ -83,7 +83,7 @@ Return a dict containing `multi_modal_data` with the audio, and either a `prompt
         cls,
         audio: np.ndarray,
         stt_config: SpeechToTextConfig,
-        model_config: ModelConfig,
+        renderer_config: RendererConfig,
         language: str | None,
         task_type: Literal["transcribe", "translate"],
         request_prompt: str,
@@ -120,7 +120,7 @@ Return a dict with separate `encoder_prompt` and `decoder_prompt` entries:
         cls,
         audio: np.ndarray,
         stt_config: SpeechToTextConfig,
-        model_config: ModelConfig,
+        renderer_config: RendererConfig,
         language: str | None,
         task_type: Literal["transcribe", "translate"],
         request_prompt: str,
@@ -183,7 +183,7 @@ Provide a fast duration→token estimate to improve streaming usage statistics:
         cls,
         audio_duration_s: float,
         stt_config: SpeechToTextConfig,
-        model_config: ModelConfig,
+        renderer_config: RendererConfig,
     ) -> int | None:
         # Return None if unknown; otherwise return an estimate.
         return int(audio_duration_s * stt_config.sample_rate // 320)  # example
@@ -216,7 +216,7 @@ Relevant server logic:
         prompt = self.model_cls.get_generation_prompt(
             audio=chunk,
             stt_config=self.asr_config,
-            model_config=self.model_config,
+            renderer_config=self.renderer_config,
             language=language,
             task_type=self.task_type,
             request_prompt=request.prompt,
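For model implementers, the documentation change above is a signature update: the SupportsTranscription hooks now receive renderer_config: RendererConfig where they previously received model_config: ModelConfig. A hedged sketch of what an implementation looks like after this change; the class name and the SpeechToTextConfig value are placeholders, while the argument names follow the diff:

    from typing import Literal

    from vllm.config import RendererConfig, SpeechToTextConfig
    from vllm.model_executor.models.interfaces import SupportsTranscription


    class MyWhisperLikeModel(SupportsTranscription):  # placeholder class name
        @classmethod
        def get_speech_to_text_config(
            cls,
            renderer_config: RendererConfig,  # was: model_config: ModelConfig
            task_type: Literal["transcribe", "translate"],
        ) -> SpeechToTextConfig:
            # Illustrative value; a real model fills in its own defaults.
            return SpeechToTextConfig(sample_rate=16_000)

The other hooks shown in the diff (get_generation_prompt and the duration-to-token estimate) change in the same way: only the parameter name and type differ.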

tests/compile/distributed/test_sequence_parallelism.py

Lines changed: 2 additions & 0 deletions

@@ -17,6 +17,7 @@
     DeviceConfig,
     ModelConfig,
     PassConfig,
+    RendererConfig,
     VllmConfig,
     get_current_vllm_config,
     set_current_vllm_config,
@@ -276,6 +277,7 @@ def sequence_parallelism_pass_on_test_model(
 
     vllm_config = VllmConfig(
         model_config=model_config,
+        renderer_config=RendererConfig(model_config=model_config),
         device_config=device_config,
         compilation_config=compilation_config,
     )

tests/compile/test_functionalization.py

Lines changed: 5 additions & 1 deletion

@@ -15,6 +15,7 @@
     CompilationConfig,
     ModelConfig,
     PassConfig,
+    RendererConfig,
     VllmConfig,
     set_current_vllm_config,
 )
@@ -219,8 +220,11 @@ def test_fix_functionalization(
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
 
+    model_config = ModelConfig(dtype=dtype)
+
     vllm_config = VllmConfig(
-        model_config=ModelConfig(dtype=dtype),
+        model_config=model_config,
+        renderer_config=RendererConfig(model_config=model_config),
         compilation_config=CompilationConfig(
             custom_ops=["all"],
             pass_config=PassConfig(

tests/compile/test_fusion.py

Lines changed: 5 additions & 1 deletion

@@ -15,6 +15,7 @@
     CompilationMode,
     ModelConfig,
     PassConfig,
+    RendererConfig,
     VllmConfig,
 )
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -154,8 +155,11 @@ def test_fusion_rmsnorm_quant(
         custom_ops.append("+rms_norm")
     if enable_quant_fp8_custom_op:
         custom_ops.append("+quant_fp8")
+
+    model_config = ModelConfig(dtype=dtype)
     vllm_config = VllmConfig(
-        model_config=ModelConfig(dtype=dtype),
+        model_config=model_config,
+        renderer_config=RendererConfig(model_config=model_config),
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             custom_ops=custom_ops,

tests/compile/test_fusion_attn.py

Lines changed: 2 additions & 0 deletions

@@ -24,6 +24,7 @@
     CompilationMode,
     ModelConfig,
     PassConfig,
+    RendererConfig,
     SchedulerConfig,
     VllmConfig,
     set_current_vllm_config,
@@ -325,6 +326,7 @@ def test_attention_quant_pattern(
     )
     vllm_config = VllmConfig(
         model_config=model_config,
+        renderer_config=RendererConfig(model_config=model_config),
         scheduler_config=SchedulerConfig(
             max_num_seqs=1024,
             max_model_len=model_config.max_model_len,

tests/compile/test_pass_manager.py

Lines changed: 6 additions & 2 deletions

@@ -7,7 +7,7 @@
 
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.compilation.pass_manager import PostGradPassManager
-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import ModelConfig, RendererConfig, VllmConfig
 
 
 # dummy custom pass that doesn't inherit
@@ -43,7 +43,11 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None:
 )
 def test_pass_manager_uuid(callable):
     # Some passes need dtype to be set
-    config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
+    model_config = ModelConfig(dtype=torch.bfloat16)
+    config = VllmConfig(
+        model_config=model_config,
+        renderer_config=RendererConfig(model_config=model_config),
+    )
 
     pass_manager = PostGradPassManager()
     pass_manager.configure(config)

tests/compile/test_qk_norm_rope_fusion.py

Lines changed: 4 additions & 1 deletion

@@ -19,6 +19,7 @@
     CompilationMode,
     ModelConfig,
     PassConfig,
+    RendererConfig,
     VllmConfig,
     set_current_vllm_config,
 )
@@ -133,8 +134,10 @@ def test_qk_norm_rope_fusion(
     if enable_rope_custom_op:
         custom_ops.append("+rotary_embedding")
 
+    model_config = ModelConfig(dtype=dtype)
     vllm_config = VllmConfig(
-        model_config=ModelConfig(dtype=dtype),
+        model_config=model_config,
+        renderer_config=RendererConfig(model_config=model_config),
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             custom_ops=custom_ops,

tests/distributed/test_kvlayout.py

Lines changed: 3 additions & 0 deletions

@@ -5,6 +5,7 @@
     DeviceConfig,
     KVTransferConfig,
     ModelConfig,
+    RendererConfig,
     VllmConfig,
     set_current_vllm_config,
 )
@@ -47,6 +48,7 @@ def test_get_kv_connector_cache_layout_with_nixl_connector():
     vllm_config = VllmConfig(
         device_config=DeviceConfig("cpu"),
         model_config=model_config,
+        renderer_config=RendererConfig(model_config=model_config),
         kv_transfer_config=kv_transfer_config,
     )
     with set_current_vllm_config(vllm_config):
@@ -70,6 +72,7 @@ def test_get_kv_connector_cache_layout_with_multi_connector():
     vllm_config = VllmConfig(
         device_config=DeviceConfig("cpu"),
         model_config=model_config,
+        renderer_config=RendererConfig(model_config=model_config),
         kv_transfer_config=kv_transfer_config,
     )
     with set_current_vllm_config(vllm_config):

tests/entrypoints/openai/test_chat_template.py

Lines changed: 4 additions & 18 deletions

@@ -3,7 +3,6 @@
 
 import pytest
 
-from vllm.config import ModelConfig
 from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.tokenizers import get_tokenizer
@@ -107,24 +106,11 @@ def test_get_gen_prompt(
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
     model_info.check_available_online(on_fail="skip")
 
-    model_config = ModelConfig(
-        model,
-        tokenizer=model_info.tokenizer or model,
-        tokenizer_mode=model_info.tokenizer_mode,
-        trust_remote_code=model_info.trust_remote_code,
-        revision=model_info.revision,
-        hf_overrides=model_info.hf_overrides,
-        skip_tokenizer_init=model_info.require_embed_inputs,
-        enable_prompt_embeds=model_info.require_embed_inputs,
-        enable_mm_embeds=model_info.require_embed_inputs,
-        enforce_eager=model_info.enforce_eager,
-        dtype=model_info.dtype,
-    )
+    renderer_config = model_info.build_renderer_config(model)
 
-    # Initialize the tokenizer
     tokenizer = get_tokenizer(
-        tokenizer_name=model_config.tokenizer,
-        trust_remote_code=model_config.trust_remote_code,
+        renderer_config.tokenizer,
+        trust_remote_code=renderer_config.trust_remote_code,
     )
     template_content = load_chat_template(chat_template=template)
 
@@ -143,7 +129,7 @@ def test_get_gen_prompt(
         tokenizer=tokenizer,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
-        model_config=model_config,
+        renderer_config=renderer_config,
         tools=None,
         add_generation_prompt=mock_request.add_generation_prompt,
         continue_final_message=mock_request.continue_final_message,

tests/entrypoints/openai/test_lora_resolvers.py

Lines changed: 15 additions & 6 deletions

@@ -33,26 +33,34 @@ class MockModelConfig:
     """Minimal mock ModelConfig for testing."""
 
     model: str = MODEL_NAME
-    tokenizer: str = MODEL_NAME
     trust_remote_code: bool = False
-    tokenizer_mode: str = "auto"
     max_model_len: int = 100
-    tokenizer_revision: str | None = None
     multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig)
     hf_config: MockHFConfig = field(default_factory=MockHFConfig)
     logits_processors: list[str] | None = None
     logits_processor_pattern: str | None = None
     diff_sampling_param: dict | None = None
-    allowed_local_media_path: str = ""
-    allowed_media_domains: list[str] | None = None
     encoder_config = None
     generation_config: str = "auto"
-    skip_tokenizer_init: bool = False
 
     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}
 
 
+@dataclass
+class MockRendererConfig:
+    """Minimal mock RendererConfig for testing."""
+
+    model_config: MockModelConfig
+
+    tokenizer: str = MODEL_NAME
+    tokenizer_mode: str = "auto"
+    tokenizer_revision: str | None = None
+    skip_tokenizer_init: bool = False
+    allowed_local_media_path: str = ""
+    allowed_media_domains: list[str] | None = None
+
+
 class MockLoRAResolver(LoRAResolver):
     async def resolve_lora(
         self, base_model_name: str, lora_name: str
@@ -114,6 +122,7 @@ async def mock_generate(*args, **kwargs):
     mock_engine.add_lora.reset_mock()
 
     mock_engine.model_config = MockModelConfig()
+    mock_engine.renderer_config = MockRendererConfig(mock_engine.model_config)
     mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()
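Taken together, the last two test files show the consumer-side change: code that used to read tokenizer settings off the model config (or off engine.model_config) now reads them off the renderer config. A short sketch using only the attribute names and the get_tokenizer call visible in the diffs above; the engine object here is hypothetical:

    from vllm.tokenizers import get_tokenizer

    def build_tokenizer(engine):  # "engine" stands in for any object exposing renderer_config
        renderer_config = engine.renderer_config  # previously: engine.model_config
        return get_tokenizer(
            renderer_config.tokenizer,
            trust_remote_code=renderer_config.trust_remote_code,
        )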
