Skip to content

Commit 53121d6

Browse files
authored
switch to pydantic v2 (#485)
Closes #483.
1 parent 4a1a360 commit 53121d6

File tree

15 files changed

+104
-142
lines changed

15 files changed

+104
-142
lines changed

gto/_pydantic.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

gto/base.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import datetime
22
from typing import Any, Dict, FrozenSet, List, Optional, Sequence, Union
33

4+
from pydantic import BaseModel, ConfigDict
45
from scmrepo.git import Git
56

67
from gto.config import RegistryConfig
@@ -12,7 +13,6 @@
1213
)
1314
from gto.versions import SemVer
1415

15-
from ._pydantic import BaseModel
1616
from .exceptions import (
1717
ArtifactNotFound,
1818
ManyVersions,
@@ -41,7 +41,7 @@ def event(self):
4141
return self.__class__.__name__.lower()
4242

4343
def dict_state(self, exclude=None):
44-
state = self.dict(exclude=exclude)
44+
state = self.model_dump(exclude=exclude)
4545
state["event"] = self.event
4646
return state
4747

@@ -178,7 +178,7 @@ def ref(self):
178178
return self.authoring_event.ref
179179

180180
def dict_state(self, exclude=None):
181-
version = self.dict(exclude=exclude)
181+
version = self.model_dump(exclude=exclude)
182182
version["is_active"] = self.is_active
183183
version["activated_at"] = self.activated_at
184184
version["created_at"] = self.created_at
@@ -565,9 +565,7 @@ def find_version_at_commit(
565565

566566
class BaseRegistryState(BaseModel):
567567
artifacts: Dict[str, Artifact] = {}
568-
569-
class Config:
570-
arbitrary_types_allowed = True
568+
model_config = ConfigDict(arbitrary_types_allowed=True)
571569

572570
def add_artifact(self, name):
573571
self.artifacts[name] = Artifact(artifact=name, versions=[])
@@ -623,9 +621,7 @@ class BaseManager(BaseModel):
623621
scm: Git
624622
actions: FrozenSet[Action]
625623
config: RegistryConfig
626-
627-
class Config:
628-
arbitrary_types_allowed = True
624+
model_config = ConfigDict(arbitrary_types_allowed=True)
629625

630626
def update_state(self, state: BaseRegistryState) -> BaseRegistryState:
631627
raise NotImplementedError

gto/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,7 @@ def stages(
810810
def print_state(repo: str = option_repo):
811811
"""Technical cmd: Print current registry state."""
812812
state = make_ready_to_serialize(
813-
gto.api._get_state(repo).dict() # pylint: disable=protected-access
813+
gto.api._get_state(repo).model_dump() # pylint: disable=protected-access
814814
)
815815
format_echo(state, "json")
816816

@@ -833,7 +833,7 @@ def doctor(
833833
echo(f"{EMOJI_FAIL} Fail to parse config")
834834
echo("---------------------------------")
835835

836-
gto.api._get_state(repo).dict() # pylint: disable=protected-access
836+
gto.api._get_state(repo).model_dump() # pylint: disable=protected-access
837837
with cli_echo():
838838
echo(f"{EMOJI_OK} No issues found")
839839

gto/config.py

Lines changed: 54 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
# pylint: disable=no-self-argument, inconsistent-return-statements, invalid-name, import-outside-toplevel
22
import pathlib
3-
from pathlib import Path
43
from typing import Any, Dict, List, Optional
54

5+
from pydantic import BaseModel, Field, field_validator
6+
from pydantic_settings import (
7+
BaseSettings,
8+
PydanticBaseSettingsSource,
9+
SettingsConfigDict,
10+
)
11+
from pydantic_settings import (
12+
YamlConfigSettingsSource as _YamlConfigSettingsSource,
13+
)
614
from ruamel.yaml import YAML
715

816
from gto.constants import assert_name_is_valid
917
from gto.exceptions import UnknownStage, UnknownType, WrongConfig
1018
from gto.ext import EnrichmentReader, find_enrichment_types, find_enrichments
1119

12-
from ._pydantic import BaseModel, BaseSettings, InitSettingsSource, validator
13-
1420
yaml = YAML(typ="safe", pure=True)
1521
yaml.default_flow_style = False
1622

@@ -27,45 +33,47 @@ def load(self) -> EnrichmentReader:
2733

2834
class NoFileConfig(BaseSettings): # type: ignore[valid-type]
2935
INDEX: str = "artifacts.yaml"
30-
TYPES: Optional[List[str]] = None
31-
STAGES: Optional[List[str]] = None
36+
CONFIG_FILE_NAME: Optional[str] = CONFIG_FILE_NAME
3237
LOG_LEVEL: str = "INFO"
3338
DEBUG: bool = False
34-
ENRICHMENTS: List[EnrichmentConfig] = []
35-
AUTOLOAD_ENRICHMENTS: bool = True
36-
CONFIG_FILE_NAME: Optional[str] = CONFIG_FILE_NAME
3739
EMOJIS: bool = True
3840

39-
class Config:
40-
env_prefix = "gto_"
41+
types: Optional[List[str]] = None
42+
stages: Optional[List[str]] = None
43+
enrichments: List[EnrichmentConfig] = Field(default_factory=list)
44+
autoload_enrichments: bool = True
45+
46+
model_config = SettingsConfigDict(env_prefix="gto_")
4147

4248
def assert_type(self, name):
4349
assert_name_is_valid(name)
4450
# pylint: disable-next=unsupported-membership-test
45-
if self.TYPES is not None and name not in self.TYPES:
46-
raise UnknownType(name, self.TYPES)
51+
if self.types is not None and name not in self.types:
52+
raise UnknownType(name, self.types)
4753

4854
def assert_stage(self, name):
4955
assert_name_is_valid(name)
5056
# pylint: disable-next=unsupported-membership-test
51-
if self.STAGES is not None and name not in self.STAGES:
52-
raise UnknownStage(name, self.STAGES)
57+
if self.stages is not None and name not in self.stages:
58+
raise UnknownStage(name, self.stages)
5359

5460
@property
55-
def enrichments(self) -> Dict[str, EnrichmentReader]:
56-
res = {e.source: e for e in (e.load() for e in self.ENRICHMENTS)}
57-
if self.AUTOLOAD_ENRICHMENTS:
61+
def enrichments_(self) -> Dict[str, EnrichmentReader]:
62+
res = {e.source: e for e in (e.load() for e in self.enrichments)}
63+
if self.autoload_enrichments:
5864
return {**find_enrichments(), **res}
5965
return res
6066

61-
@validator("TYPES")
67+
@field_validator("types")
68+
@classmethod
6269
def types_are_valid(cls, v): # pylint: disable=no-self-use
6370
if v:
6471
for name in v:
6572
assert_name_is_valid(name)
6673
return v
6774

68-
@validator("STAGES")
75+
@field_validator("stages")
76+
@classmethod
6977
def stages_are_valid(cls, v): # pylint: disable=no-self-use
7078
if v:
7179
for name in v:
@@ -77,61 +85,48 @@ def check_index_exist(self, repo: str):
7785
return index.exists() and index.is_file()
7886

7987

80-
def _set_location_init_source(init_source: InitSettingsSource):
81-
def inner(settings: "RegistryConfig"):
82-
if "CONFIG_FILE_NAME" in init_source.init_kwargs:
83-
settings.__dict__["CONFIG_FILE_NAME"] = init_source.init_kwargs[
84-
"CONFIG_FILE_NAME"
85-
]
86-
return {}
87-
88-
return inner
88+
class YamlConfigSettingsSource(_YamlConfigSettingsSource):
89+
def _read_file(self, file_path: pathlib.Path) -> dict[str, Any]:
90+
with open(file_path, encoding=self.yaml_file_encoding) as yaml_file:
91+
return yaml.load(yaml_file) or {}
8992

9093

91-
def config_settings_source(settings: "RegistryConfig") -> Dict[str, Any]:
92-
"""
93-
A simple settings source that loads variables from a yaml file in GTO DIR
94-
"""
95-
96-
encoding = settings.__config__.env_file_encoding
97-
config_file = getattr(settings, "CONFIG_FILE_NAME", CONFIG_FILE_NAME)
98-
if not isinstance(config_file, Path):
99-
config_file = Path(config_file)
100-
if not config_file.exists():
101-
return {}
102-
conf = yaml.load(config_file.read_text(encoding=encoding))
103-
104-
return {k.upper(): v for k, v in conf.items()} if conf else {}
94+
class RegistryConfig(NoFileConfig):
95+
model_config = SettingsConfigDict(env_prefix="gto_", env_file_encoding="utf-8")
10596

97+
def config_file_exists(self):
98+
config = pathlib.Path(self.CONFIG_FILE_NAME)
99+
return config.exists() and config.is_file()
106100

107-
class RegistryConfig(NoFileConfig):
108-
class Config:
109-
env_prefix = "gto_"
110-
env_file_encoding = "utf-8"
111101

102+
def read_registry_config(config_file_name) -> "RegistryConfig":
103+
class _RegistryConfig(RegistryConfig):
112104
@classmethod
113-
def customise_sources(
105+
def settings_customise_sources(
114106
cls,
115-
init_settings,
116-
env_settings,
117-
file_secret_settings,
107+
settings_cls: type[BaseSettings],
108+
init_settings: PydanticBaseSettingsSource,
109+
env_settings: PydanticBaseSettingsSource,
110+
dotenv_settings: PydanticBaseSettingsSource,
111+
file_secret_settings: PydanticBaseSettingsSource,
118112
):
113+
encoding = getattr(settings_cls.model_config, "env_file_encoding", "utf-8")
119114
return (
120-
_set_location_init_source(init_settings),
121115
init_settings,
122116
env_settings,
123-
config_settings_source,
117+
(
118+
YamlConfigSettingsSource(
119+
settings_cls,
120+
yaml_file=config_file_name,
121+
yaml_file_encoding=encoding,
122+
)
123+
),
124+
dotenv_settings,
124125
file_secret_settings,
125126
)
126127

127-
def config_file_exists(self):
128-
config = pathlib.Path(self.CONFIG_FILE_NAME)
129-
return config.exists() and config.is_file()
130-
131-
132-
def read_registry_config(config_file_name):
133128
try:
134-
return RegistryConfig(CONFIG_FILE_NAME=config_file_name)
129+
return _RegistryConfig(CONFIG_FILE_NAME=config_file_name)
135130
except Exception as e: # pylint: disable=bare-except
136131
raise WrongConfig(config_file_name) from e
137132

gto/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from enum import Enum
33
from typing import Optional
44

5-
from gto.exceptions import ValidationError
5+
from pydantic import BaseModel
66

7-
from ._pydantic import BaseModel
7+
from gto.exceptions import ValidationError
88

99
COMMIT = "commit"
1010
REF = "ref"

gto/ext.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
from typing import Dict, Optional, Type, Union
55

66
import entrypoints
7+
from pydantic import BaseModel
78
from scmrepo.git import Git
89

9-
from ._pydantic import BaseModel
10-
1110
ENRICHMENT_ENTRYPOINT = "gto.enrichment"
1211

1312

@@ -29,7 +28,7 @@ def get_object(self) -> BaseModel:
2928
raise NotImplementedError
3029

3130
def get_dict(self):
32-
return self.get_object().dict()
31+
return self.get_object().model_dump()
3332

3433
@abstractmethod
3534
def get_human_readable(self) -> str:

0 commit comments

Comments
 (0)