|
2 | 2 |
|
3 | 3 | import multiprocessing |
4 | 4 |
|
5 | | -from typing import Optional, List, Literal, Union |
6 | | -from pydantic import Field, root_validator |
| 5 | +from typing import Optional, List, Literal, Union, Dict, cast |
| 6 | +from typing_extensions import Self |
| 7 | + |
| 8 | +from pydantic import Field, model_validator |
7 | 9 | from pydantic_settings import BaseSettings |
8 | 10 |
|
9 | 11 | import llama_cpp |
@@ -173,15 +175,16 @@ class ModelSettings(BaseSettings): |
173 | 175 | default=True, description="Whether to print debug information." |
174 | 176 | ) |
175 | 177 |
|
176 | | - @root_validator(pre=True) # pre=True to ensure this runs before any other validation |
177 | | - def set_dynamic_defaults(cls, values): |
| 178 | + @model_validator(mode="before") # pre=True to ensure this runs before any other validation |
| 179 | + def set_dynamic_defaults(self) -> Self: |
178 | 180 | # If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count() |
179 | 181 | cpu_count = multiprocessing.cpu_count() |
| 182 | + values = cast(Dict[str, int], self) |
180 | 183 | if values.get('n_threads', 0) == -1: |
181 | 184 | values['n_threads'] = cpu_count |
182 | 185 | if values.get('n_threads_batch', 0) == -1: |
183 | 186 | values['n_threads_batch'] = cpu_count |
184 | | - return values |
| 187 | + return self |
185 | 188 |
|
186 | 189 |
|
187 | 190 | class ServerSettings(BaseSettings): |
|
0 commit comments