|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
3 |
| - |
| 3 | +"""Utility functions for vLLM config dataclasses.""" |
4 | 4 | import ast
|
5 | 5 | import inspect
|
6 | 6 | import textwrap
|
7 |
| -from dataclasses import MISSING, Field, field, fields, is_dataclass |
8 |
| -from typing import TYPE_CHECKING, Any, TypeVar |
| 7 | +from dataclasses import MISSING, Field, field, fields, is_dataclass, replace |
| 8 | +from typing import TYPE_CHECKING, Any, Protocol, TypeVar |
9 | 9 |
|
10 | 10 | import regex as re
|
| 11 | +from typing_extensions import runtime_checkable |
11 | 12 |
|
12 | 13 | if TYPE_CHECKING:
|
13 | 14 | from _typeshed import DataclassInstance
|
14 |
| - |
15 |
| - ConfigType = type[DataclassInstance] |
16 | 15 | else:
|
17 |
| - ConfigType = type |
| 16 | + DataclassInstance = Any |
18 | 17 |
|
| 18 | +ConfigType = type[DataclassInstance] |
19 | 19 | ConfigT = TypeVar("ConfigT", bound=ConfigType)
|
20 | 20 |
|
21 | 21 |
|
@@ -143,3 +143,33 @@ def pairwise(iterable):
|
143 | 143 |
|
144 | 144 | def is_init_field(cls: ConfigType, name: str) -> bool:
|
145 | 145 | return next(f for f in fields(cls) if f.name == name).init
|
| 146 | + |
| 147 | + |
| 148 | +@runtime_checkable |
| 149 | +class SupportsHash(Protocol): |
| 150 | + |
| 151 | + def compute_hash(self) -> str: |
| 152 | + ... |
| 153 | + |
| 154 | + |
| 155 | +class SupportsMetricsInfo(Protocol): |
| 156 | + |
| 157 | + def metrics_info(self) -> dict[str, str]: |
| 158 | + ... |
| 159 | + |
| 160 | + |
| 161 | +def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT: |
| 162 | + processed_overrides = {} |
| 163 | + for field_name, value in overrides.items(): |
| 164 | + assert hasattr( |
| 165 | + config, field_name), f"{type(config)} has no field `{field_name}`" |
| 166 | + current_value = getattr(config, field_name) |
| 167 | + if is_dataclass(current_value) and not is_dataclass(value): |
| 168 | + assert isinstance(value, dict), ( |
| 169 | + f"Overrides to {type(config)}.{field_name} must be a dict" |
| 170 | + f" or {type(current_value)}, but got {type(value)}") |
| 171 | + value = update_config( |
| 172 | + current_value, # type: ignore[type-var] |
| 173 | + value) |
| 174 | + processed_overrides[field_name] = value |
| 175 | + return replace(config, **processed_overrides) |
0 commit comments