1+ import inspect
12import json
23import logging
34from pathlib import Path
4- from typing import Any , TypeAlias
5+ from types import UnionType
6+ from typing import Any , TypeAlias , Union , get_args , get_origin
57
68import joblib
79import numpy as np
810import numpy .typing as npt
11+ from pydantic import BaseModel
912from sklearn .base import BaseEstimator
1013
1114from autointent import Embedder , Ranker , VectorIndex
12- from autointent .schemas import TagsList
15+ from autointent .schemas import CrossEncoderConfig , EmbedderConfig , TagsList
1316
1417ModuleSimpleAttributes = None | str | int | float | bool | list # type: ignore[type-arg]
1518
@@ -28,6 +31,7 @@ class Dumper:
2831 indexes = "vector_indexes"
2932 estimators = "estimators"
3033 cross_encoders = "cross_encoders"
34+ pydantic_models : str = "pydantic"
3135
3236 @staticmethod
3337 def make_subdirectories (path : Path ) -> None :
@@ -37,12 +41,13 @@ def make_subdirectories(path: Path) -> None:
3741 path / Dumper .indexes ,
3842 path / Dumper .estimators ,
3943 path / Dumper .cross_encoders ,
44+ path / Dumper .pydantic_models ,
4045 ]
4146 for subdir in subdirectories :
4247 subdir .mkdir (parents = True , exist_ok = True )
4348
4449 @staticmethod
45- def dump (obj : Any , path : Path ) -> None : # noqa: ANN401
50+ def dump (obj : Any , path : Path ) -> None : # noqa: ANN401, C901
4651 """Dump modules attributes to filestystem."""
4752 attrs : dict [str , ModuleAttributes ] = vars (obj )
4853 simple_attrs = {}
@@ -65,6 +70,14 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401
6570 joblib .dump (val , path / Dumper .estimators / key )
6671 elif isinstance (val , Ranker ):
6772 val .save (str (path / Dumper .cross_encoders / key ))
73+ elif isinstance (val , CrossEncoderConfig | EmbedderConfig ):
74+ try :
75+ pydantic_path = path / Dumper .pydantic_models / f"{ key } .json"
76+ with pydantic_path .open ("w" , encoding = "utf-8" ) as file :
77+ json .dump (val .model_dump (), file , ensure_ascii = False , indent = 4 )
78+ except Exception as e :
79+ msg = f"Error dumping pydantic model { key } : { e } "
80+ logging .exception (msg )
6881 else :
6982 msg = f"Attribute { key } of type { type (val )} cannot be dumped to file system."
7083 logger .error (msg )
@@ -75,8 +88,17 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401
7588 np .savez (path / Dumper .arrays , allow_pickle = False , ** arrays )
7689
7790 @staticmethod
78- def load (obj : Any , path : Path ) -> None : # noqa: ANN401
91+ def load (obj : Any , path : Path ) -> None : # noqa: ANN401, PLR0912, C901, PLR0915
7992 """Load attributes from file system."""
93+ tags : dict [str , Any ] = {}
94+ simple_attrs : dict [str , Any ] = {}
95+ arrays : dict [str , Any ] = {}
96+ embedders : dict [str , Any ] = {}
97+ indexes : dict [str , Any ] = {}
98+ estimators : dict [str , Any ] = {}
99+ cross_encoders : dict [str , Any ] = {}
100+ pydantic_models : dict [str , Any ] = {}
101+
80102 for child in path .iterdir ():
81103 if child .name == Dumper .tags :
82104 tags = {tags_dump .name : TagsList .load (tags_dump ) for tags_dump in child .iterdir ()}
@@ -96,7 +118,46 @@ def load(obj: Any, path: Path) -> None: # noqa: ANN401
96118 cross_encoders = {
97119 cross_encoder_dump .name : Ranker .load (cross_encoder_dump ) for cross_encoder_dump in child .iterdir ()
98120 }
121+ elif child .name == Dumper .pydantic_models :
122+ for model_file in child .iterdir ():
123+ with model_file .open ("r" , encoding = "utf-8" ) as file :
124+ content = json .load (file )
125+ variable_name = model_file .stem
126+
127+ # First try to get the type annotation from the class annotations.
128+ model_type = obj .__class__ .__annotations__ .get (variable_name )
129+
130+ # Fallback: inspect __init__ signature if not found in class-level annotations.
131+ if model_type is None :
132+ sig = inspect .signature (obj .__init__ )
133+ if variable_name in sig .parameters :
134+ model_type = sig .parameters [variable_name ].annotation
135+
136+ if model_type is None :
137+ msg = f"No type annotation found for { variable_name } "
138+ logger .error (msg )
139+ continue
140+
141+ # If the annotation is a Union, extract the pydantic model type.
142+ if get_origin (model_type ) in (UnionType , Union ):
143+ for arg in get_args (model_type ):
144+ if isinstance (arg , type ) and issubclass (arg , BaseModel ):
145+ model_type = arg
146+ break
147+ else :
148+ msg = f"No pydantic type found in Union for { variable_name } "
149+ logger .error (msg )
150+ continue
151+
152+ if not (isinstance (model_type , type ) and issubclass (model_type , BaseModel )):
153+ msg = f"Type for { variable_name } is not a pydantic model: { model_type } "
154+ logger .error (msg )
155+ continue
156+
157+ pydantic_models [variable_name ] = model_type (** content )
99158 else :
100159 msg = f"Found unexpected child { child } "
101160 logger .error (msg )
102- obj .__dict__ .update (tags | simple_attrs | arrays | embedders | indexes | estimators | cross_encoders )
161+ obj .__dict__ .update (
162+ tags | simple_attrs | arrays | embedders | indexes | estimators | cross_encoders | pydantic_models
163+ )
0 commit comments