|
| 1 | +from typing import Any, Dict, Set |
| 2 | + |
| 3 | +from pydantic import BaseModel, Field |
| 4 | + |
| 5 | + |
| 6 | +class BaseConfig(BaseModel): |
| 7 | + class Config: |
| 8 | + extra = "allow" # Allow arbitrary attributes |
| 9 | + |
| 10 | + def __init__(self, **data: Any) -> None: |
| 11 | + super().__init__(**data) |
| 12 | + self.__dict__["_list_fields"]: Set[str] = set() |
| 13 | + self.__dict__["_alias"]: Dict[str, str] = {} |
| 14 | + |
| 15 | + def __getitem__(self, key: str) -> Any: |
| 16 | + return getattr(self, key) |
| 17 | + |
| 18 | + def __setitem__(self, key: str, value: Any): |
| 19 | + setattr(self, key, value) |
| 20 | + |
| 21 | + def __getattr__(self, name): |
| 22 | + """Handles alias access and custom parameters.""" |
| 23 | + if name in self._alias: |
| 24 | + return getattr(self, self._alias[name]) |
| 25 | + |
| 26 | + def __setattr__(self, name, value): |
| 27 | + """Handles alias assignment, field setting, or adding to _param.""" |
| 28 | + if name in self._alias: |
| 29 | + name = self._alias[name] |
| 30 | + if name in self._list_fields and not isinstance(value, list): |
| 31 | + value = [value] |
| 32 | + super().__setattr__(name, value) |
| 33 | + |
| 34 | + def __contains__(self, key: str) -> bool: |
| 35 | + return hasattr(self, key) |
| 36 | + |
| 37 | + def __repr__(self): |
| 38 | + attrs = {k: v for k, v in self.__dict__.items() if not k.startswith("_")} |
| 39 | + attr_str = "\n".join(f" {key}: {value!r}" for key, value in attrs.items()) |
| 40 | + return f"{self.__class__.__name__}(\n{attr_str}\n)" |
| 41 | + |
| 42 | + def set_alias(self, name: str, alias: str) -> None: |
| 43 | + self.__dict__["_alias"][alias] = name |
| 44 | + |
| 45 | + def ensure_list(self, name: str): |
| 46 | + """Mark the field to always be treated as a list""" |
| 47 | + value = getattr(self, name, None) |
| 48 | + if value is not None and not isinstance(value, list): |
| 49 | + setattr(self, name, [value]) |
| 50 | + self._list_fields.add(name) |
| 51 | + |
| 52 | + |
| 53 | +class Foo(BaseConfig): |
| 54 | + a: int = 1 |
| 55 | + |
| 56 | + class Config: |
| 57 | + extra = "allow" |
| 58 | + |
| 59 | + |
| 60 | +print(Foo(**{"a": 1, "b": 2}).model_dump()) # == {'a': 1, 'b': 2} |
| 61 | + |
| 62 | +foo = Foo() |
| 63 | +foo.b = 2 |
| 64 | +print(foo.model_dump()) |
0 commit comments