Skip to content

Commit d54ebf2

Browse files
authored
Merge pull request #348 from loriab/csse_pyd2_510_pt2_more
Csse pyd2 510 Part 2
2 parents c219aa2 + a1f7dcc commit d54ebf2

File tree

9 files changed

+270
-72
lines changed

9 files changed

+270
-72
lines changed

docs/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ Enhancements
3838
* The ``models.v2`` have had their `schema_version` bumped for ``BasisSet``, ``AtomicInput``, ``OptimizationInput`` (implicit for ``AtomicResult`` and ``OptimizationResult``), ``TorsionDriveInput`` , and ``TorsionDriveResult``.
3939
* The ``models.v2`` ``AtomicResultProperties`` has been given a ``schema_name`` and ``schema_version`` (2) for the first time.
4040
* Note that ``models.v2`` ``QCInputSpecification`` and ``OptimizationSpecification`` have *not* had schema_version bumped.
41+
* All of ``Datum``, ``DFTFunctional``, and ``CPUInfo`` models, none of which are mixed with QCSchema models, are translated to Pydantic v2 API syntax.
4142

4243
Bug Fixes
4344
+++++++++

qcelemental/datum.py

Lines changed: 91 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,61 @@
33
"""
44

55
from decimal import Decimal
6-
from typing import Any, Dict, Optional
6+
from typing import Any, Dict, Optional, Union
77

88
import numpy as np
9+
from pydantic import (
10+
BaseModel,
11+
ConfigDict,
12+
SerializationInfo,
13+
SerializerFunctionWrapHandler,
14+
WrapSerializer,
15+
field_validator,
16+
model_serializer,
17+
)
18+
from typing_extensions import Annotated
19+
20+
21+
def reduce_complex(data):
22+
# Reduce Complex
23+
if isinstance(data, complex):
24+
return [data.real, data.imag]
25+
# Fallback
26+
return data
27+
28+
29+
def keep_decimal_cast_ndarray_complex(
30+
v: Any, nxt: SerializerFunctionWrapHandler, info: SerializationInfo
31+
) -> Union[list, Decimal, float]:
32+
"""
33+
Ensure Decimal types are preserved on the way out
34+
35+
This arose because Decimal was serialized to string and "dump" is equal to "serialize" in v2 pydantic
36+
https://docs.pydantic.dev/latest/migration/#changes-to-json-schema-generation
37+
38+
39+
This also checks against NumPy Arrays and complex numbers in the instance of being in JSON mode
40+
"""
41+
if isinstance(v, Decimal):
42+
return v
43+
if info.mode == "json":
44+
if isinstance(v, complex):
45+
return nxt(reduce_complex(v))
46+
if isinstance(v, np.ndarray):
47+
# Handle NDArray and complex NDArray
48+
flat_list = v.flatten().tolist()
49+
reduced_list = list(map(reduce_complex, flat_list))
50+
return nxt(reduced_list)
51+
try:
52+
# Cast NumPy scalar data types to native Python data type
53+
v = v.item()
54+
except (AttributeError, ValueError):
55+
pass
56+
return nxt(v)
57+
958

10-
try:
11-
from pydantic.v1 import BaseModel, validator
12-
except ImportError: # Will also trap ModuleNotFoundError
13-
from pydantic import BaseModel, validator
59+
# Only 1 serializer is allowed. You can't chain wrap serializers.
60+
AnyArrayComplex = Annotated[Any, WrapSerializer(keep_decimal_cast_ndarray_complex)]
1461

1562

1663
class Datum(BaseModel):
@@ -38,15 +85,15 @@ class Datum(BaseModel):
3885
numeric: bool
3986
label: str
4087
units: str
41-
data: Any
88+
data: AnyArrayComplex
4289
comment: str = ""
4390
doi: Optional[str] = None
4491
glossary: str = ""
4592

46-
class Config:
47-
extra = "forbid"
48-
allow_mutation = False
49-
json_encoders = {np.ndarray: lambda v: v.flatten().tolist(), complex: lambda v: (v.real, v.imag)}
93+
model_config = ConfigDict(
94+
extra="forbid",
95+
frozen=True,
96+
)
5097

5198
def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None, numeric=True):
5299
kwargs = {"label": label, "units": units, "data": data, "numeric": numeric}
@@ -59,20 +106,21 @@ def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None,
59106

60107
super().__init__(**kwargs)
61108

62-
@validator("data")
63-
def must_be_numerical(cls, v, values, **kwargs):
109+
@field_validator("data")
110+
@classmethod
111+
def must_be_numerical(cls, v, info):
64112
try:
65113
1.0 * v
66114
except TypeError:
67115
try:
68116
Decimal("1.0") * v
69117
except TypeError:
70-
if values["numeric"]:
118+
if info.data["numeric"]:
71119
raise ValueError(f"Datum data should be float, Decimal, or np.ndarray, not {type(v)}.")
72120
else:
73-
values["numeric"] = True
121+
info.data["numeric"] = True
74122
else:
75-
values["numeric"] = True
123+
info.data["numeric"] = True
76124

77125
return v
78126

@@ -90,8 +138,35 @@ def __str__(self, label=""):
90138
text.append("-" * width)
91139
return "\n".join(text)
92140

141+
@model_serializer(mode="wrap")
142+
def _serialize_model(self, handler) -> Dict[str, Any]:
143+
"""
144+
Customize the serialization output. Does duplicate with some code in model_dump, but handles the case of nested
145+
models and any model config options.
146+
147+
Encoding is handled at the `model_dump` level and not here as that should happen only after EVERYTHING has been
148+
dumped/de-pydantic-ized.
149+
"""
150+
151+
# Get the default return, let the model_dump handle kwarg
152+
default_result = handler(self)
153+
# Exclude unset always
154+
output_dict = {key: value for key, value in default_result.items() if key in self.model_fields_set}
155+
return output_dict
156+
93157
def dict(self, *args, **kwargs):
94-
return super().dict(*args, **{**kwargs, **{"exclude_unset": True}})
158+
"""
159+
Passthrough to model_dump without deprecation warning
160+
exclude_unset is forced through the model_serializer
161+
"""
162+
return super().model_dump(*args, **kwargs)
163+
164+
def json(self, *args, **kwargs):
165+
"""
166+
Passthrough to model_dump_sjon without deprecation warning
167+
exclude_unset is forced through the model_serializer
168+
"""
169+
return super().model_dump_json(*args, **kwargs)
95170

96171
def to_units(self, units=None):
97172
from .physical_constants import constants

qcelemental/info/cpu_info.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
from functools import lru_cache
99
from typing import List, Optional
1010

11-
from pydantic.v1 import Field
11+
from pydantic import BeforeValidator, Field
12+
from typing_extensions import Annotated
1213

13-
from ..models import ProtoModel
14+
from ..models.v2 import ProtoModel
15+
16+
# ProcessorInfo models don't become parts of QCSchema models afaik, so pure pydantic v2 API
1417

1518

1619
class VendorEnum(str, Enum):
@@ -22,6 +25,13 @@ class VendorEnum(str, Enum):
2225
arm = "arm"
2326

2427

28+
def stringify(v) -> str:
29+
return str(v)
30+
31+
32+
Stringify = Annotated[str, BeforeValidator(stringify)]
33+
34+
2535
class InstructionSetEnum(int, Enum):
2636
"""Allowed instruction sets for CPUs in an ordinal enum."""
2737

@@ -37,13 +47,13 @@ class ProcessorInfo(ProtoModel):
3747
ncores: int = Field(..., description="The number of physical cores on the chip.")
3848
nthreads: Optional[int] = Field(..., description="The maximum number of concurrent threads.")
3949
base_clock: float = Field(..., description="The base clock frequency (GHz).")
40-
boost_clock: Optional[float] = Field(..., description="The boost clock frequency (GHz).")
41-
model: str = Field(..., description="The model number of the chip.")
50+
boost_clock: Optional[float] = Field(None, description="The boost clock frequency (GHz).")
51+
model: Stringify = Field(..., description="The model number of the chip.")
4252
family: str = Field(..., description="The family of the chip.")
43-
launch_date: Optional[int] = Field(..., description="The launch year of the chip.")
53+
launch_date: Optional[int] = Field(None, description="The launch year of the chip.")
4454
target_use: str = Field(..., description="Target use case (Desktop, Server, etc).")
4555
vendor: VendorEnum = Field(..., description="The vendor the chip is produced by.")
46-
microarchitecture: Optional[str] = Field(..., description="The microarchitecture the chip follows.")
56+
microarchitecture: Optional[str] = Field(None, description="The microarchitecture the chip follows.")
4757
instructions: InstructionSetEnum = Field(..., description="The maximum vectorized instruction set available.")
4858
type: str = Field(..., description="The type of chip (cpu, gpu, etc).")
4959

qcelemental/info/dft_info.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44

55
from typing import Dict
66

7-
from pydantic.v1 import Field
7+
from pydantic import Field
8+
from typing_extensions import Annotated
89

9-
from ..models import ProtoModel
10+
from ..models.v2 import ProtoModel
11+
12+
# DFTFunctional models don't become parts of QCSchema models afaik, so pure pydantic v2 API
1013

1114

1215
class DFTFunctionalInfo(ProtoModel):
@@ -68,4 +71,4 @@ def get(name: str) -> DFTFunctionalInfo:
6871
name = name.replace(x, "")
6972
break
7073

71-
return dftfunctionalinfo.functionals[name].copy()
74+
return dftfunctionalinfo.functionals[name].model_copy()

qcelemental/testing.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union
66

77
import numpy as np
8-
9-
try:
10-
from pydantic.v1 import BaseModel
11-
except ImportError: # Will also trap ModuleNotFoundError
12-
from pydantic import BaseModel
8+
import pydantic
139

1410
if TYPE_CHECKING:
1511
from qcelemental.models import ProtoModel # TODO: recheck if .v1 needed
@@ -313,10 +309,16 @@ def _compare_recursive(expected, computed, atol, rtol, _prefix=False, equal_phas
313309
prefix = name + "."
314310

315311
# Initial conversions if required
316-
if isinstance(expected, BaseModel):
312+
if isinstance(expected, pydantic.BaseModel):
313+
expected = expected.model_dump()
314+
315+
if isinstance(computed, pydantic.BaseModel):
316+
computed = computed.model_dump()
317+
318+
if isinstance(expected, pydantic.v1.BaseModel):
317319
expected = expected.dict()
318320

319-
if isinstance(computed, BaseModel):
321+
if isinstance(computed, pydantic.v1.BaseModel):
320322
computed = computed.dict()
321323

322324
if isinstance(expected, (str, int, bool, complex)):
@@ -381,8 +383,8 @@ def _compare_recursive(expected, computed, atol, rtol, _prefix=False, equal_phas
381383

382384

383385
def compare_recursive(
384-
expected: Union[Dict, BaseModel, "ProtoModel"], # type: ignore
385-
computed: Union[Dict, BaseModel, "ProtoModel"], # type: ignore
386+
expected: Union[Dict, pydantic.BaseModel, pydantic.v1.BaseModel, "ProtoModel"], # type: ignore
387+
computed: Union[Dict, pydantic.BaseModel, pydantic.v1.BaseModel, "ProtoModel"], # type: ignore
386388
label: str = None,
387389
*,
388390
atol: float = 1.0e-6,

qcelemental/tests/test_datum.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
from decimal import Decimal
22

33
import numpy as np
4-
5-
try:
6-
import pydantic.v1 as pydantic
7-
except ImportError: # Will also trap ModuleNotFoundError
8-
import pydantic
4+
import pydantic
95
import pytest
106

117
import qcelemental as qcel
@@ -46,10 +42,10 @@ def test_creation_nonnum(dataset):
4642

4743

4844
def test_creation_error():
49-
with pytest.raises(pydantic.ValidationError):
45+
with pytest.raises(pydantic.ValidationError) as e:
5046
qcel.Datum("ze lbl", "ze unit", "ze data")
5147

52-
# assert 'Datum data should be float' in str(e)
48+
assert "Datum data should be float" in str(e.value)
5349

5450

5551
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)