1- import inspect
1+ import importlib
22import json
33import logging
44from pathlib import Path
5- from types import UnionType
6- from typing import Any , TypeAlias , Union , get_args , get_origin
5+ from typing import Any , TypeAlias
76
87import joblib
98import numpy as np
1312
1413from autointent import Embedder , Ranker , VectorIndex
1514from autointent .configs import CrossEncoderConfig , EmbedderConfig
16- from autointent .context ._utils import NumpyEncoder
1715from autointent .schemas import TagsList
1816
1917ModuleSimpleAttributes = None | str | int | float | bool | list # type: ignore[type-arg]
@@ -36,11 +34,12 @@ class Dumper:
3634 pydantic_models : str = "pydantic"
3735
3836 @staticmethod
39- def make_subdirectories (path : Path ) -> None :
37+ def make_subdirectories (path : Path , exists_ok : bool = False ) -> None :
4038 """Make subdirectories for dumping.
4139
4240 Args:
4341 path: Path to make subdirectories in
42+ exists_ok: If True, do not raise an error if the directory already exists
4443 """
4544 subdirectories = [
4645 path / Dumper .tags ,
@@ -51,23 +50,27 @@ def make_subdirectories(path: Path) -> None:
5150 path / Dumper .pydantic_models ,
5251 ]
5352 for subdir in subdirectories :
54- subdir .mkdir (parents = True , exist_ok = True )
53+ subdir .mkdir (parents = True , exist_ok = exists_ok )
5554
5655 @staticmethod
57- def dump (obj : Any , path : Path ) -> None : # noqa: ANN401, C901
56+ def dump (obj : Any , path : Path , exists_ok : bool = False , exclude : list [ type [ Any ]] | None = None ) -> None : # noqa: ANN401, C901
5857 """Dump modules attributes to filestystem.
5958
6059 Args:
6160 obj: Object to dump
6261 path: Path to dump to
62+ exists_ok: If True, do not raise an error if the directory already exists
63+ exclude: List of types to exclude from dumping
6364 """
6465 attrs : dict [str , ModuleAttributes ] = vars (obj )
6566 simple_attrs = {}
6667 arrays : dict [str , npt .NDArray [Any ]] = {}
6768
68- Dumper .make_subdirectories (path )
69+ Dumper .make_subdirectories (path , exists_ok )
6970
7071 for key , val in attrs .items ():
72+ if exclude and isinstance (val , tuple (exclude )):
73+ continue
7174 if isinstance (val , TagsList ):
7275 val .dump (path / Dumper .tags / key )
7376 elif isinstance (val , ModuleSimpleAttributes ):
@@ -84,9 +87,13 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901
8487 val .save (str (path / Dumper .cross_encoders / key ))
8588 elif isinstance (val , BaseModel ):
8689 try :
87- pydantic_path = path / Dumper .pydantic_models / f"{ key } .json"
88- with pydantic_path .open ("w" , encoding = "utf-8" ) as file :
89- json .dump (val .model_dump (), file , ensure_ascii = False , indent = 4 , cls = NumpyEncoder )
90+ class_info = {"name" : val .__class__ .__name__ , "module" : val .__class__ .__module__ }
91+ pydantic_path = path / Dumper .pydantic_models / key
92+ pydantic_path .mkdir (parents = True , exist_ok = exists_ok )
93+ with (pydantic_path / "class_info.json" ).open ("w" , encoding = "utf-8" ) as file :
94+ json .dump (class_info , file , ensure_ascii = False , indent = 4 )
95+ with (pydantic_path / "model_dump.json" ).open ("w" , encoding = "utf-8" ) as file :
96+ json .dump (val .model_dump (), file , ensure_ascii = False , indent = 4 )
9097 except Exception as e :
9198 msg = f"Error dumping pydantic model { key } : { e } "
9299 logging .exception (msg )
@@ -100,7 +107,7 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901
100107 np .savez (path / Dumper .arrays , allow_pickle = False , ** arrays )
101108
102109 @staticmethod
103- def load ( # noqa: PLR0912, C901 , PLR0915
110+ def load ( # noqa: C901, PLR0912 , PLR0915
104111 obj : Any , # noqa: ANN401
105112 path : Path ,
106113 embedder_config : EmbedderConfig | None = None ,
@@ -139,42 +146,34 @@ def load( # noqa: PLR0912, C901, PLR0915
139146 for cross_encoder_dump in child .iterdir ()
140147 }
141148 elif child .name == Dumper .pydantic_models :
142- for model_file in child .iterdir ():
143- with model_file .open ("r" , encoding = "utf-8" ) as file :
144- content = json .load (file )
145- variable_name = model_file .stem
146-
147- # First try to get the type annotation from the class annotations.
148- model_type = obj .__class__ .__annotations__ .get (variable_name )
149-
150- # Fallback: inspect __init__ signature if not found in class-level annotations.
151- if model_type is None :
152- sig = inspect .signature (obj .__init__ )
153- if variable_name in sig .parameters :
154- model_type = sig .parameters [variable_name ].annotation
155-
156- if model_type is None :
157- msg = f"No type annotation found for { variable_name } "
158- logger .error (msg )
159- continue
160-
161- # If the annotation is a Union, extract the pydantic model type.
162- if get_origin (model_type ) in (UnionType , Union ):
163- for arg in get_args (model_type ):
164- if isinstance (arg , type ) and issubclass (arg , BaseModel ):
165- model_type = arg
166- break
167- else :
168- msg = f"No pydantic type found in Union for { variable_name } "
169- logger .error (msg )
149+ for model_dir in child .iterdir ():
150+ try :
151+ with (model_dir / "model_dump.json" ).open ("r" , encoding = "utf-8" ) as file :
152+ content = json .load (file )
153+
154+ variable_name = model_dir .name
155+
156+ with (model_dir / "class_info.json" ).open ("r" , encoding = "utf-8" ) as file :
157+ class_info = json .load (file )
158+
159+ try :
160+ model_type = importlib .import_module (class_info ["module" ])
161+ model_type = getattr (model_type , class_info ["name" ])
162+ except (ImportError , AttributeError ) as e :
163+ msg = f"Failed to import model type for { variable_name } : { e } "
164+ logger .exception (msg )
170165 continue
171166
172- if not (isinstance (model_type , type ) and issubclass (model_type , BaseModel )):
173- msg = f"Type for { variable_name } is not a pydantic model: { model_type } "
174- logger .error (msg )
167+ try :
168+ pydantic_models [variable_name ] = model_type .model_validate (content )
169+ except Exception as e :
170+ msg = f"Failed to reconstruct Pydantic model { variable_name } : { e } "
171+ logger .exception (msg )
172+ continue
173+ except Exception as e :
174+ msg = f"Error loading Pydantic model from { model_dir } : { e } "
175+ logger .exception (msg )
175176 continue
176-
177- pydantic_models [variable_name ] = model_type (** content )
178177 else :
179178 msg = f"Found unexpected child { child } "
180179 logger .error (msg )
0 commit comments