Skip to content

Commit 8d63801

Browse files
committed
migrate to pydantic v2
1 parent ecfc1c7 commit 8d63801

File tree

8 files changed

+144
-73
lines changed

8 files changed

+144
-73
lines changed

catalax/model/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ class CatalaxBase(BaseModel):
66
class Config:
77
use_enum_values = True
88
arbitrary_types_allowed = True
9-
allow_mutation = True
109
validate_assignment = True
1110

12-
__repr_fields__: List[str] = PrivateAttr(default=["__all__"])
11+
_repr_fields: List[str] = PrivateAttr(default=["__all__"])

catalax/model/model.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@ class Model(CatalaxBase):
5050

5151
class Config:
5252
arbitrary_types_allowed = True
53-
fields = {
54-
"term": {"exclude": True},
55-
}
5653

5754
name: str
5855
odes: Dict[str, ODE] = Field(default_factory=DottedDict)
@@ -116,7 +113,7 @@ def add_ode(
116113
equation=equation, species=self.species[species], observable=observable
117114
)
118115

119-
self.odes[species].__model__ = self
116+
self.odes[species]._model = self
120117

121118
def add_species(self, species_string: str = "", **species_map):
122119
"""Adds a single or multiple species to the model, which can later be used in ODEs.
@@ -164,7 +161,6 @@ def add_species(self, species_string: str = "", **species_map):
164161

165162
# Make sure the symbol is valid
166163
check_symbol(symbol)
167-
168164
self.species[symbol] = Species(name=name, symbol=Symbol(symbol))
169165

170166
@staticmethod

catalax/model/ode.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,13 @@
1414
class ODE(CatalaxBase):
1515
class Config:
1616
arbitrary_types_allowed = True
17-
fields = {
18-
"parameters": {"exclude": True},
19-
"species": {"exclude": True},
20-
}
2117

2218
species: Species
2319
equation: Expr
2420
observable: bool = True
2521
parameters: Dict[Union[str, Expr], Parameter] = Field(default_factory=DottedDict)
2622

27-
__model__: Optional["Model"] = PrivateAttr(default=None) # type: ignore
23+
_model: Optional["Model"] = PrivateAttr(default=None) # type: ignore
2824

2925
@validator("equation", pre=True)
3026
def converts_ode_to_sympy(cls, value):
@@ -37,7 +33,7 @@ def __setattr__(self, name, value):
3733

3834
super().__setattr__(name, value)
3935

40-
if name == "__model__":
36+
if name == "_model":
4137
self.add_parameters_to_model()
4238

4339
def add_parameters_to_model(self):
@@ -48,20 +44,20 @@ def add_parameters_to_model(self):
4844
done upon addition of the model due to no given knowledge of the species.
4945
"""
5046

51-
if self.__model__ is None:
47+
if self._model is None:
5248
return None
5349

5450
for symbol in self.equation.free_symbols:
55-
if str(symbol) in self.__model__.species or str(symbol) == "t":
51+
if str(symbol) in self._model.species or str(symbol) == "t":
5652
# Skip species and time symbol
5753
continue
58-
elif parameter_exists(str(symbol), self.__model__.parameters):
54+
elif parameter_exists(str(symbol), self._model.parameters):
5955
# Assign parameter if it is already present in the model
60-
self.parameters[str(symbol)] = self.__model__.parameters[str(symbol)]
56+
self.parameters[str(symbol)] = self._model.parameters[str(symbol)]
6157
continue
6258

6359
# Create a new one and add it to the model and ODE
6460
parameter = Parameter(name=str(symbol), symbol=symbol) # type: ignore
6561

6662
self.parameters[str(symbol)] = parameter
67-
self.__model__.parameters[str(symbol)] = parameter
63+
self._model.parameters[str(symbol)] = parameter

catalax/model/parameter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class Parameter(CatalaxBase):
2525
prior: Any = None # TODO: Fix this typing
2626
_prior_str_: Optional[str] = None
2727

28-
@root_validator()
28+
@root_validator(skip_on_failure=True)
2929
def _assign_prior_string(cls, values):
3030
if isinstance(values["prior"], tuple):
3131
prior, prior_str = values["prior"]
@@ -34,7 +34,7 @@ def _assign_prior_string(cls, values):
3434

3535
return values
3636

37-
__repr_fields__: List[str] = PrivateAttr(
37+
_repr_fields: List[str] = PrivateAttr(
3838
default={
3939
"name": "name",
4040
"symbol": "symbol",

catalax/model/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def __repr__(self):
3737

3838
@staticmethod
3939
def _fields_to_print(cls_: CatalaxBase) -> Dict[str, str]:
40-
if len(cls_.__repr_fields__) == 1 and cls_.__repr_fields__[0] == "__all__":
40+
if len(cls_._repr_fields) == 1 and cls_._repr_fields[0] == "__all__":
4141
return {key: key for key in cls_.__dict__}
4242

43-
return cls_.__repr_fields__
43+
return cls_._repr_fields
4444

4545

4646
def odeprint(y, expr):
@@ -77,16 +77,16 @@ def check_symbol(symbol: str) -> None:
7777
"""Checks whether the given symbol is a valid symbol"""
7878

7979
ERROR_MESSAGE = f"""Symbol '{symbol}' is not a valid symbol. The following rules apply:
80-
80+
8181
(1) The first character must be a letter
8282
(2) The remaining characters can be letters and numbers
8383
(3) The symbol cannot end with an underscore
8484
(4) The symbol can contain at most one underscore followed by letters and numbers
85-
85+
8686
These are valid symbols:
87-
87+
8888
k1, k_12, k_max, k_max1
89-
89+
9090
"""
9191

9292
# Convert to string to use string methods

catalax/tools/simulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class Config:
5757
parameters: List[str]
5858
stoich_mat: jax.Array
5959
dt0: float = 0.1
60-
solver: AbstractSolver = Tsit5
60+
solver: Any = Tsit5
6161
rtol: float = 1e-5
6262
atol: float = 1e-5
6363
max_steps: int = 64**4

0 commit comments

Comments
 (0)