
Commit e479143

[Feature] Use pydantic validation in lora.py and load.py configs (vllm-project#26413)
Signed-off-by: simondanielsson <[email protected]>
1 parent e6e898f · commit e479143

File tree

4 files changed: +48 -45 lines changed

vllm/config/load.py

Lines changed: 17 additions & 9 deletions
```diff
@@ -2,9 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import hashlib
-from dataclasses import field
 from typing import TYPE_CHECKING, Any, Optional, Union
 
+from pydantic import Field, field_validator
 from pydantic.dataclasses import dataclass
 
 from vllm.config.utils import config
@@ -64,15 +64,17 @@ class LoadConfig:
     was quantized using torchao and saved using safetensors.
     Needs torchao >= 0.14.0
     """
-    model_loader_extra_config: Union[dict, TensorizerConfig] = field(
+    model_loader_extra_config: Union[dict, TensorizerConfig] = Field(
         default_factory=dict
     )
     """Extra config for model loader. This will be passed to the model loader
     corresponding to the chosen load_format."""
     device: Optional[str] = None
     """Device to which model weights will be loaded, default to
     device_config.device"""
-    ignore_patterns: Optional[Union[list[str], str]] = None
+    ignore_patterns: Union[list[str], str] = Field(
+        default_factory=lambda: ["original/**/*"]
+    )
     """The list of patterns to ignore when loading the model. Default to
     "original/**/*" to avoid repeated loading of llama's checkpoints."""
     use_tqdm_on_load: bool = True
@@ -107,12 +109,18 @@ def compute_hash(self) -> str:
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
 
-    def __post_init__(self):
-        self.load_format = self.load_format.lower()
-        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
+    @field_validator("load_format", mode="after")
+    def _lowercase_load_format(cls, load_format: str) -> str:
+        return load_format.lower()
+
+    @field_validator("ignore_patterns", mode="after")
+    def _validate_ignore_patterns(
+        cls, ignore_patterns: Union[list[str], str]
+    ) -> Union[list[str], str]:
+        if ignore_patterns != ["original/**/*"] and len(ignore_patterns) > 0:
             logger.info(
                 "Ignoring the following patterns when downloading weights: %s",
-                self.ignore_patterns,
+                ignore_patterns,
             )
-        else:
-            self.ignore_patterns = ["original/**/*"]
+
+        return ignore_patterns
```
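The hunk above swaps eager `__post_init__` normalization for pydantic `field_validator`s. As a minimal, self-contained sketch of that pattern (the `DemoLoadConfig` name is hypothetical, not vLLM code):

```python
# Hedged sketch of the field_validator pattern adopted above; DemoLoadConfig
# is a stand-in, not the real vllm.config.load.LoadConfig.
from typing import Union

from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass


@dataclass
class DemoLoadConfig:
    load_format: str = "auto"
    ignore_patterns: Union[list[str], str] = Field(
        default_factory=lambda: ["original/**/*"]
    )

    @field_validator("load_format", mode="after")
    @classmethod
    def _lowercase_load_format(cls, load_format: str) -> str:
        # Runs at construction time, after type validation of the field.
        return load_format.lower()


print(DemoLoadConfig(load_format="SAFETENSORS").load_format)  # safetensors
```

Note that `mode="after"` validators fire at construction; they re-run on later attribute assignment only if `validate_assignment` is enabled in the model config.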

vllm/config/lora.py

Lines changed: 22 additions & 35 deletions
```diff
@@ -5,8 +5,9 @@
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
 
 import torch
-from pydantic import ConfigDict
+from pydantic import ConfigDict, Field, model_validator
 from pydantic.dataclasses import dataclass
+from typing_extensions import Self
 
 import vllm.envs as envs
 from vllm.config.utils import config
@@ -23,16 +24,18 @@
 logger = init_logger(__name__)
 
 LoRADType = Literal["auto", "float16", "bfloat16"]
+MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
+LoRAExtraVocabSize = Literal[256, 512]
 
 
 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class LoRAConfig:
     """Configuration for LoRA."""
 
-    max_lora_rank: int = 16
+    max_lora_rank: MaxLoRARanks = 16
     """Max LoRA rank."""
-    max_loras: int = 1
+    max_loras: int = Field(default=1, ge=1)
     """Max number of LoRAs in a single batch."""
     fully_sharded_loras: bool = False
     """By default, only half of the LoRA computation is sharded with tensor
@@ -44,7 +47,14 @@ class LoRAConfig:
     `max_loras`."""
     lora_dtype: Union[torch.dtype, LoRADType] = "auto"
     """Data type for LoRA. If auto, will default to base model dtype."""
-    lora_extra_vocab_size: int = 256
+    lora_extra_vocab_size: LoRAExtraVocabSize = Field(
+        default=256,
+        deprecated=(
+            "`lora_extra_vocab_size` is deprecated and will be removed "
+            "in v0.12.0. Additional vocabulary support for "
+            "LoRA adapters is being phased out."
+        ),
+    )
     """(Deprecated) Maximum size of extra vocabulary that can be present in a
     LoRA adapter. Will be removed in v0.12.0."""
     lora_vocab_padding_size: ClassVar[int] = (
@@ -60,7 +70,10 @@ class LoRAConfig:
     per prompt. When run in offline mode, the lora IDs for n modalities
     will be automatically assigned to 1-n with the names of the modalities
     in alphabetic order."""
-    bias_enabled: bool = False
+    bias_enabled: bool = Field(
+        default=False,
+        deprecated="`bias_enabled` is deprecated and will be removed in v0.12.0.",
+    )
     """[DEPRECATED] Enable bias for LoRA adapters. This option will be
     removed in v0.12.0."""
 
@@ -87,36 +100,8 @@ def compute_hash(self) -> str:
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
 
-    def __post_init__(self):
-        # Deprecation warning for lora_extra_vocab_size
-        logger.warning(
-            "`lora_extra_vocab_size` is deprecated and will be removed "
-            "in v0.12.0. Additional vocabulary support for "
-            "LoRA adapters is being phased out."
-        )
-
-        # Deprecation warning for enable_lora_bias
-        if self.bias_enabled:
-            logger.warning(
-                "`enable_lora_bias` is deprecated and will be removed in v0.12.0."
-            )
-
-        # Setting the maximum rank to 512 should be able to satisfy the vast
-        # majority of applications.
-        possible_max_ranks = (1, 8, 16, 32, 64, 128, 256, 320, 512)
-        possible_lora_extra_vocab_size = (256, 512)
-        if self.max_lora_rank not in possible_max_ranks:
-            raise ValueError(
-                f"max_lora_rank ({self.max_lora_rank}) must be one of "
-                f"{possible_max_ranks}."
-            )
-        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
-            raise ValueError(
-                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
-                f"must be one of {possible_lora_extra_vocab_size}."
-            )
-        if self.max_loras < 1:
-            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
+    @model_validator(mode="after")
+    def _validate_lora_config(self) -> Self:
         if self.max_cpu_loras is None:
             self.max_cpu_loras = self.max_loras
         elif self.max_cpu_loras < self.max_loras:
@@ -125,6 +110,8 @@ def __post_init__(self):
                 f"max_loras ({self.max_loras})"
             )
 
+        return self
+
     def verify_with_cache_config(self, cache_config: CacheConfig):
         if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1:
             raise ValueError("V0 LoRA does not support CPU offload, please use V1.")
```

vllm/config/utils.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -11,6 +11,7 @@
 from typing import TYPE_CHECKING, Any, Protocol, TypeVar
 
 import regex as re
+from pydantic.fields import FieldInfo
 from typing_extensions import runtime_checkable
 
 if TYPE_CHECKING:
@@ -50,7 +51,14 @@ def get_field(cls: ConfigType, name: str) -> Field:
     if (default_factory := named_field.default_factory) is not MISSING:
         return field(default_factory=default_factory)
     if (default := named_field.default) is not MISSING:
+        if isinstance(default, FieldInfo):
+            # Handle pydantic.Field defaults
+            if default.default_factory is not None:
+                return field(default_factory=default.default_factory)
+            else:
+                default = default.default
         return field(default=default)
+
     raise ValueError(
         f"{cls.__name__}.{name} must have a default value or default factory."
     )
```
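Because `pydantic.Field(...)` evaluates to a `FieldInfo` object, code that reads dataclass defaults reflectively now has to unwrap it, which is what the branch above does. A self-contained sketch of that unwrapping (`unwrap_field_default` is a hypothetical helper, not the vLLM `get_field`):

```python
# Hedged sketch of the FieldInfo branch added above; unwrap_field_default is
# a hypothetical helper, not the vLLM get_field.
from dataclasses import field

from pydantic import Field
from pydantic.fields import FieldInfo


def unwrap_field_default(default):
    """Mirror a possibly-pydantic default as a plain dataclasses.field."""
    if isinstance(default, FieldInfo):
        # pydantic.Field(...) evaluates to a FieldInfo carrying either a
        # default_factory or a plain default value.
        if default.default_factory is not None:
            return field(default_factory=default.default_factory)
        default = default.default
    return field(default=default)


f = unwrap_field_default(Field(default_factory=lambda: ["original/**/*"]))
print(f.default_factory())  # ['original/**/*']
```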

vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -452,7 +452,7 @@ class EngineArgs:
     num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override
     num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
     model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
-    ignore_patterns: Optional[Union[str, list[str]]] = LoadConfig.ignore_patterns
+    ignore_patterns: Union[str, list[str]] = get_field(LoadConfig, "ignore_patterns")
 
     enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill
     disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
```
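This one-line change follows from load.py: `LoadConfig.ignore_patterns` is now backed by a `default_factory`, so reading the class attribute would hand `EngineArgs` a `FieldInfo` rather than a usable default, and `get_field` is needed to rebuild a proper `dataclasses.field`. A hedged sketch of why factory-backed defaults cannot simply be copied (`DemoEngineArgs` is hypothetical):

```python
# Hedged sketch, not vLLM code: a default_factory gives each instance its
# own list, so mutating one instance's default never leaks into another.
from dataclasses import dataclass, field


def ignore_patterns_factory() -> list[str]:
    return ["original/**/*"]


@dataclass
class DemoEngineArgs:
    # Stands in for `get_field(LoadConfig, "ignore_patterns")` above.
    ignore_patterns: list[str] = field(default_factory=ignore_patterns_factory)


a, b = DemoEngineArgs(), DemoEngineArgs()
a.ignore_patterns.append("extra")
print(b.ignore_patterns)  # ['original/**/*'] -- defaults are not shared
```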
