Skip to content

Commit 0694ebd

Browse files
authored
load pydantic models (#143)
1 parent cd6167a commit 0694ebd

File tree

1 file changed

+66
-5
lines changed

1 file changed

+66
-5
lines changed

autointent/_dump_tools.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1+
import inspect
12
import json
23
import logging
34
from pathlib import Path
4-
from typing import Any, TypeAlias
5+
from types import UnionType
6+
from typing import Any, TypeAlias, Union, get_args, get_origin
57

68
import joblib
79
import numpy as np
810
import numpy.typing as npt
11+
from pydantic import BaseModel
912
from sklearn.base import BaseEstimator
1013

1114
from autointent import Embedder, Ranker, VectorIndex
12-
from autointent.schemas import TagsList
15+
from autointent.schemas import CrossEncoderConfig, EmbedderConfig, TagsList
1316

1417
ModuleSimpleAttributes = None | str | int | float | bool | list # type: ignore[type-arg]
1518

@@ -28,6 +31,7 @@ class Dumper:
2831
indexes = "vector_indexes"
2932
estimators = "estimators"
3033
cross_encoders = "cross_encoders"
34+
pydantic_models: str = "pydantic"
3135

3236
@staticmethod
3337
def make_subdirectories(path: Path) -> None:
@@ -37,12 +41,13 @@ def make_subdirectories(path: Path) -> None:
3741
path / Dumper.indexes,
3842
path / Dumper.estimators,
3943
path / Dumper.cross_encoders,
44+
path / Dumper.pydantic_models,
4045
]
4146
for subdir in subdirectories:
4247
subdir.mkdir(parents=True, exist_ok=True)
4348

4449
@staticmethod
45-
def dump(obj: Any, path: Path) -> None: # noqa: ANN401
50+
def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901
4651
"""Dump module attributes to the filesystem."""
4752
attrs: dict[str, ModuleAttributes] = vars(obj)
4853
simple_attrs = {}
@@ -65,6 +70,14 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401
6570
joblib.dump(val, path / Dumper.estimators / key)
6671
elif isinstance(val, Ranker):
6772
val.save(str(path / Dumper.cross_encoders / key))
73+
elif isinstance(val, CrossEncoderConfig | EmbedderConfig):
74+
try:
75+
pydantic_path = path / Dumper.pydantic_models / f"{key}.json"
76+
with pydantic_path.open("w", encoding="utf-8") as file:
77+
json.dump(val.model_dump(), file, ensure_ascii=False, indent=4)
78+
except Exception as e:
79+
msg = f"Error dumping pydantic model {key}: {e}"
80+
logging.exception(msg)
6881
else:
6982
msg = f"Attribute {key} of type {type(val)} cannot be dumped to file system."
7083
logger.error(msg)
@@ -75,8 +88,17 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401
7588
np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)
7689

7790
@staticmethod
78-
def load(obj: Any, path: Path) -> None: # noqa: ANN401
91+
def load(obj: Any, path: Path) -> None: # noqa: ANN401, PLR0912, C901, PLR0915
7992
"""Load attributes from file system."""
93+
tags: dict[str, Any] = {}
94+
simple_attrs: dict[str, Any] = {}
95+
arrays: dict[str, Any] = {}
96+
embedders: dict[str, Any] = {}
97+
indexes: dict[str, Any] = {}
98+
estimators: dict[str, Any] = {}
99+
cross_encoders: dict[str, Any] = {}
100+
pydantic_models: dict[str, Any] = {}
101+
80102
for child in path.iterdir():
81103
if child.name == Dumper.tags:
82104
tags = {tags_dump.name: TagsList.load(tags_dump) for tags_dump in child.iterdir()}
@@ -96,7 +118,46 @@ def load(obj: Any, path: Path) -> None: # noqa: ANN401
96118
cross_encoders = {
97119
cross_encoder_dump.name: Ranker.load(cross_encoder_dump) for cross_encoder_dump in child.iterdir()
98120
}
121+
elif child.name == Dumper.pydantic_models:
122+
for model_file in child.iterdir():
123+
with model_file.open("r", encoding="utf-8") as file:
124+
content = json.load(file)
125+
variable_name = model_file.stem
126+
127+
# First try to get the type annotation from the class annotations.
128+
model_type = obj.__class__.__annotations__.get(variable_name)
129+
130+
# Fallback: inspect __init__ signature if not found in class-level annotations.
131+
if model_type is None:
132+
sig = inspect.signature(obj.__init__)
133+
if variable_name in sig.parameters:
134+
model_type = sig.parameters[variable_name].annotation
135+
136+
if model_type is None:
137+
msg = f"No type annotation found for {variable_name}"
138+
logger.error(msg)
139+
continue
140+
141+
# If the annotation is a Union, extract the pydantic model type.
142+
if get_origin(model_type) in (UnionType, Union):
143+
for arg in get_args(model_type):
144+
if isinstance(arg, type) and issubclass(arg, BaseModel):
145+
model_type = arg
146+
break
147+
else:
148+
msg = f"No pydantic type found in Union for {variable_name}"
149+
logger.error(msg)
150+
continue
151+
152+
if not (isinstance(model_type, type) and issubclass(model_type, BaseModel)):
153+
msg = f"Type for {variable_name} is not a pydantic model: {model_type}"
154+
logger.error(msg)
155+
continue
156+
157+
pydantic_models[variable_name] = model_type(**content)
99158
else:
100159
msg = f"Found unexpected child {child}"
101160
logger.error(msg)
102-
obj.__dict__.update(tags | simple_attrs | arrays | embedders | indexes | estimators | cross_encoders)
161+
obj.__dict__.update(
162+
tags | simple_attrs | arrays | embedders | indexes | estimators | cross_encoders | pydantic_models
163+
)

0 commit comments

Comments
 (0)