Skip to content

Commit 0837772

Browse files
authored
Merge pull request #692 from maresb/migrate-pydantic
Code migrations for Pydantic v2
2 parents 41d3451 + 345fbcb commit 0837772

File tree

6 files changed

+42
-32
lines changed

6 files changed

+42
-32
lines changed

conda_lock/lockfile/v1/models.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import json
55
import logging
66
import pathlib
7-
import typing
87

98
from collections import namedtuple
109
from typing import (
@@ -22,7 +21,7 @@
2221
if TYPE_CHECKING:
2322
from hashlib import _Hash
2423

25-
from pydantic import Field, validator
24+
from pydantic import Field, ValidationInfo, field_validator
2625
from typing_extensions import Literal
2726

2827
from conda_lock.common import ordered_union, relative_path
@@ -60,9 +59,10 @@ class BaseLockedDependency(StrictModel):
6059
def key(self) -> LockKey:
6160
return LockKey(self.manager, self.name, self.platform)
6261

63-
@validator("hash")
64-
def validate_hash(cls, v: HashModel, values: Dict[str, typing.Any]) -> HashModel:
65-
if (values["manager"] == "conda") and (v.md5 is None):
62+
@field_validator("hash")
63+
@classmethod
64+
def validate_hash(cls, v: HashModel, info: ValidationInfo) -> HashModel:
65+
if (info.data["manager"] == "conda") and (v.md5 is None):
6666
raise ValueError("conda package hashes must use MD5")
6767
return v
6868

@@ -217,7 +217,7 @@ class LockMeta(StrictModel):
217217
..., description="Hash of dependencies for each target platform"
218218
)
219219
channels: List[Channel] = Field(
220-
..., description="Channels used to resolve dependencies"
220+
..., description="Channels used to resolve dependencies", validate_default=True
221221
)
222222
platforms: List[str] = Field(..., description="Target platforms")
223223
sources: List[str] = Field(
@@ -282,7 +282,8 @@ def __or__(self, other: "LockMeta") -> "LockMeta":
282282
custom_metadata=new_custom_metadata,
283283
)
284284

285-
@validator("channels", pre=True, always=True)
285+
@field_validator("channels", mode="before")
286+
@classmethod
286287
def ensure_channels(cls, v: List[Union[str, Channel]]) -> List[Channel]:
287288
res: List[Channel] = []
288289
for e in v:
@@ -304,10 +305,12 @@ def dict_for_output(self) -> Dict[str, Any]:
304305
return {
305306
"version": Lockfile.version,
306307
"metadata": json.loads(
307-
self.metadata.json(by_alias=True, exclude_unset=True, exclude_none=True)
308+
self.metadata.model_dump_json(
309+
by_alias=True, exclude_unset=True, exclude_none=True
310+
)
308311
),
309312
"package": [
310-
package.dict(by_alias=True, exclude_unset=True, exclude_none=True)
313+
package.model_dump(by_alias=True, exclude_unset=True, exclude_none=True)
311314
for package in self.package
312315
],
313316
}

conda_lock/models/channel.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,14 @@ class CondaUrl(BaseModel):
5757
raw_url: str
5858
env_var_url: str
5959

60-
token: Optional[str]
61-
token_env_var: Optional[str]
60+
token: Optional[str] = None
61+
token_env_var: Optional[str] = None
6262

63-
user: Optional[str]
64-
user_env_var: Optional[str]
63+
user: Optional[str] = None
64+
user_env_var: Optional[str] = None
6565

66-
password: Optional[str]
67-
password_env_var: Optional[str]
66+
password: Optional[str] = None
67+
password_env_var: Optional[str] = None
6868

6969
@classmethod
7070
def from_string(cls, value: str) -> "CondaUrl":

conda_lock/models/lock_spec.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from typing import Dict, List, Optional, Union
77

8-
from pydantic import BaseModel, Field, validator
8+
from pydantic import BaseModel, Field, field_validator
99
from typing_extensions import Literal
1010

1111
from conda_lock.models import StrictModel
@@ -21,7 +21,8 @@ class _BaseDependency(StrictModel):
2121
extras: List[str] = []
2222
markers: Optional[str] = None
2323

24-
@validator("extras")
24+
@field_validator("extras")
25+
@classmethod
2526
def sorted_extras(cls, v: List[str]) -> List[str]:
2627
return sorted(v)
2728

@@ -53,11 +54,11 @@ class Package(StrictModel):
5354

5455

5556
class PoetryMappedDependencySpec(StrictModel):
56-
url: Optional[str]
57+
url: Optional[str] = None
5758
manager: Literal["conda", "pip"]
5859
extras: List
59-
markers: Optional[str]
60-
poetry_version_spec: Optional[str]
60+
markers: Optional[str] = None
61+
poetry_version_spec: Optional[str] = None
6162

6263

6364
class LockSpecification(BaseModel):
@@ -84,16 +85,18 @@ def content_hash_for_platform(
8485
self, platform: str, virtual_package_repo: Optional[FakeRepoData]
8586
) -> str:
8687
data = {
87-
"channels": [c.json() for c in self.channels],
88+
"channels": [c.model_dump_json() for c in self.channels],
8889
"specs": [
89-
p.dict()
90+
p.model_dump()
9091
for p in sorted(
9192
self.dependencies[platform], key=lambda p: (p.manager, p.name)
9293
)
9394
],
9495
}
9596
if self.pip_repositories:
96-
data["pip_repositories"] = [repo.json() for repo in self.pip_repositories]
97+
data["pip_repositories"] = [
98+
repo.model_dump_json() for repo in self.pip_repositories
99+
]
97100
if virtual_package_repo is not None:
98101
vpr_data = virtual_package_repo.all_repodata
99102
data["virtual_package_hash"] = {
@@ -104,7 +107,8 @@ def content_hash_for_platform(
104107
env_spec = json.dumps(data, sort_keys=True)
105108
return hashlib.sha256(env_spec.encode("utf-8")).hexdigest()
106109

107-
@validator("channels", pre=True)
110+
@field_validator("channels", mode="before")
111+
@classmethod
108112
def validate_channels(cls, v: List[Union[Channel, str]]) -> List[Channel]:
109113
for i, e in enumerate(v):
110114
if isinstance(e, str):
@@ -114,7 +118,8 @@ def validate_channels(cls, v: List[Union[Channel, str]]) -> List[Channel]:
114118
raise ValueError("nodefaults channel is not allowed, ref #418")
115119
return typing.cast(List[Channel], v)
116120

117-
@validator("pip_repositories", pre=True)
121+
@field_validator("pip_repositories", mode="before")
122+
@classmethod
118123
def validate_pip_repositories(
119124
cls, value: List[Union[PipRepository, str]]
120125
) -> List[PipRepository]:

conda_lock/virtual_package.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from types import TracebackType
99
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type
1010

11-
from pydantic import BaseModel, Field, validator
11+
from pydantic import BaseModel, ConfigDict, Field, field_validator
1212

1313
from conda_lock.interfaces.vendored_conda import MatchSpec
1414
from conda_lock.models.channel import Channel
@@ -23,8 +23,7 @@
2323
class FakePackage(BaseModel):
2424
"""A minimal representation of the required metadata for a conda package"""
2525

26-
class Config:
27-
frozen = True
26+
model_config = ConfigDict(frozen=True)
2827

2928
name: str
3029
version: str = "1.0"
@@ -36,7 +35,7 @@ class Config:
3635
package_type: Optional[str] = "virtual_system"
3736

3837
def to_repodata_entry(self) -> Tuple[str, Dict[str, Any]]:
39-
out = self.dict()
38+
out = self.model_dump()
4039
if self.build_string:
4140
build = f"{self.build_string}_{self.build_number}"
4241
else:
@@ -236,7 +235,8 @@ def default_virtual_package_repodata(cuda_version: str = "11.4") -> FakeRepoData
236235
class VirtualPackageSpecSubdir(BaseModel):
237236
packages: Dict[str, str]
238237

239-
@validator("packages")
238+
@field_validator("packages")
239+
@classmethod
240240
def validate_packages(cls, v: Dict[str, str]) -> Dict[str, str]:
241241
for package_name in v:
242242
if not package_name.startswith("__"):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ dependencies = [
3333
"ensureconda >=1.4.4",
3434
"gitpython >=3.1.30",
3535
"jinja2",
36-
"pydantic >=1.10",
36+
"pydantic >=2",
3737
"pyyaml >= 5.1",
3838
# constraint on version comes from poetry
3939
"requests >=2.26,<3.0",

tests/test_conda_lock.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1725,7 +1725,9 @@ def test_aggregate_lock_specs():
17251725
],
17261726
sources=[],
17271727
)
1728-
assert actual.dict(exclude={"sources"}) == expected.dict(exclude={"sources"})
1728+
assert actual.model_dump(exclude={"sources"}) == expected.model_dump(
1729+
exclude={"sources"}
1730+
)
17291731
assert actual.content_hash(None) == expected.content_hash(None)
17301732

17311733

0 commit comments

Comments
 (0)