 import inspect
 import json
 import logging
+import types
 from pathlib import Path
 from types import UnionType
 from typing import Any, TypeAlias, Union, get_args, get_origin
 import numpy.typing as npt
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)

 from autointent import Embedder, Ranker, VectorIndex
 from autointent.configs import CrossEncoderConfig, EmbedderConfig
 ModuleSimpleAttributes = None | str | int | float | bool | list  # type: ignore[type-arg]

 ModuleAttributes: TypeAlias = (
-    ModuleSimpleAttributes | TagsList | np.ndarray | Embedder | VectorIndex | BaseEstimator | Ranker  # type: ignore[type-arg]
+    ModuleSimpleAttributes
+    | TagsList
+    | np.ndarray
+    | Embedder
+    | VectorIndex
+    | BaseEstimator
+    | Ranker
+    | BaseModel
+    | PreTrainedModel
+    | PreTrainedTokenizer
+    | PreTrainedTokenizerFast
 )

 logger = logging.getLogger(__name__)
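
Note on the widened alias: any pydantic BaseModel (not only CrossEncoderConfig/EmbedderConfig, which the removed branch further down handled explicitly) and Hugging Face models/tokenizers now count as dumpable module attributes. A minimal sketch of what that admits, assuming the ModuleAttributes alias above is in scope; the ExampleConfig class and attribute are hypothetical, not from the autointent codebase:

from pydantic import BaseModel


class ExampleConfig(BaseModel):  # hypothetical config; any BaseModel subclass now qualifies
    batch_size: int = 32
    device: str = "cpu"


# Accepted by the widened alias; Dumper.dump() below writes it to pydantic/<attr>.json.
value: ModuleAttributes = ExampleConfig()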
@@ -33,6 +51,8 @@ class Dumper:
     estimators = "estimators"
     cross_encoders = "cross_encoders"
     pydantic_models: str = "pydantic"
+    hf_models = "hf_models"
+    hf_tokenizers = "hf_tokenizers"

     @staticmethod
     def make_subdirectories(path: Path) -> None:
@@ -48,12 +68,14 @@ def make_subdirectories(path: Path) -> None:
             path / Dumper.estimators,
             path / Dumper.cross_encoders,
             path / Dumper.pydantic_models,
+            path / Dumper.hf_models,
+            path / Dumper.hf_tokenizers,
         ]
         for subdir in subdirectories:
             subdir.mkdir(parents=True, exist_ok=True)

     @staticmethod
-    def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901
+    def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901, PLR0912
         """Dump module's attributes to the filesystem.

         Args:
@@ -67,7 +89,26 @@ def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901
         Dumper.make_subdirectories(path)

         for key, val in attrs.items():
-            if isinstance(val, TagsList):
+            if isinstance(val, PreTrainedModel):
+                try:
+                    model_path = path / Dumper.hf_models / key
+                    val.save_pretrained(model_path)
+                except Exception:
+                    logger.exception("Error dumping Hugging Face model %s", key)
+            elif isinstance(val, PreTrainedTokenizer | PreTrainedTokenizerFast):
+                try:
+                    tokenizer_path = path / Dumper.hf_tokenizers / key
+                    val.save_pretrained(tokenizer_path)
+                except Exception:
+                    logger.exception("Error dumping Hugging Face tokenizer %s", key)
+            elif isinstance(val, BaseModel):
+                try:
+                    pydantic_path = path / Dumper.pydantic_models / f"{key}.json"
+                    with pydantic_path.open("w", encoding="utf-8") as file:
+                        json.dump(val.model_dump(), file, ensure_ascii=False, indent=4)
+                except Exception:
+                    logger.exception("Error dumping pydantic model %s", key)
+            elif isinstance(val, TagsList):
                 val.dump(path / Dumper.tags / key)
             elif isinstance(val, ModuleSimpleAttributes):
                 simple_attrs[key] = val
@@ -78,25 +119,23 @@ def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901
             elif isinstance(val, VectorIndex):
                 val.dump(path / Dumper.indexes / key)
             elif isinstance(val, BaseEstimator):
-                joblib.dump(val, path / Dumper.estimators / key)
+                try:
+                    joblib.dump(val, path / Dumper.estimators / f"{key}.joblib")
+                except Exception:
+                    logger.exception("Error dumping BaseEstimator %s", key)
             elif isinstance(val, Ranker):
                 val.save(str(path / Dumper.cross_encoders / key))
-            elif isinstance(val, CrossEncoderConfig | EmbedderConfig):
-                try:
-                    pydantic_path = path / Dumper.pydantic_models / f"{key}.json"
-                    with pydantic_path.open("w", encoding="utf-8") as file:
-                        json.dump(val.model_dump(), file, ensure_ascii=False, indent=4)
-                except Exception as e:
-                    msg = f"Error dumping pydantic model {key}: {e}"
-                    logging.exception(msg)
-            else:
-                msg = f"Attribute {key} of type {type(val)} cannot be dumped to file system."
-                logger.error(msg)
-
-        with (path / Dumper.simple_attrs).open("w") as file:
+            elif not isinstance(val, type | types.ModuleType | types.FunctionType | types.MethodType):
+                logger.warning("Attribute '%s' of type %s cannot be dumped and will be skipped.", key, type(val))
+
+        with (path / Dumper.simple_attrs).open("w", encoding="utf-8") as file:
             json.dump(simple_attrs, file, ensure_ascii=False, indent=4)

-        np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)
+        if arrays:
+            try:
+                np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)
+            except Exception:
+                logger.exception("Error saving numpy arrays to %s", path / Dumper.arrays)

     @staticmethod
     def load(  # noqa: PLR0912, C901, PLR0915
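
For orientation before the load() hunk: with these changes, dump() lays a module out on disk roughly as below. Attribute names (model, tokenizer, config, clf, ranker) are placeholders, and the exact file names behind Dumper.simple_attrs and Dumper.arrays are not visible in this diff:

<path>/
    hf_models/model/             # save_pretrained() output per PreTrainedModel
    hf_tokenizers/tokenizer/     # save_pretrained() output per tokenizer
    pydantic/config.json         # model_dump() of each BaseModel attribute
    estimators/clf.joblib        # joblib.dump() per BaseEstimator (now suffixed .joblib)
    cross_encoders/ranker/       # Ranker.save()
    tags/, embedders/, indexes/  # handled as before
    <simple-attrs JSON file>     # json.dump() of simple attributes
    <arrays .npz file>           # np.savez(), written only when arrays exist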
@@ -114,69 +153,115 @@ def load(  # noqa: PLR0912, C901, PLR0915
         estimators: dict[str, Any] = {}
         cross_encoders: dict[str, Any] = {}
         pydantic_models: dict[str, Any] = {}
+        hf_models: dict[str, Any] = {}
+        hf_tokenizers: dict[str, Any] = {}

         for child in path.iterdir():
-            if child.name == Dumper.tags:
-                tags = {tags_dump.name: TagsList.load(tags_dump) for tags_dump in child.iterdir()}
-            elif child.name == Dumper.simple_attrs:
-                with child.open() as file:
-                    simple_attrs = json.load(file)
-            elif child.name == Dumper.arrays:
-                arrays = dict(np.load(child))
-            elif child.name == Dumper.embedders:
-                embedders = {
-                    embedder_dump.name: Embedder.load(embedder_dump, override_config=embedder_config)
-                    for embedder_dump in child.iterdir()
-                }
-            elif child.name == Dumper.indexes:
-                indexes = {index_dump.name: VectorIndex.load(index_dump) for index_dump in child.iterdir()}
-            elif child.name == Dumper.estimators:
-                estimators = {estimator_dump.name: joblib.load(estimator_dump) for estimator_dump in child.iterdir()}
-            elif child.name == Dumper.cross_encoders:
-                cross_encoders = {
-                    cross_encoder_dump.name: Ranker.load(cross_encoder_dump, override_config=cross_encoder_config)
-                    for cross_encoder_dump in child.iterdir()
-                }
-            elif child.name == Dumper.pydantic_models:
-                for model_file in child.iterdir():
-                    with model_file.open("r", encoding="utf-8") as file:
-                        content = json.load(file)
-                    variable_name = model_file.stem
-
-                    # First try to get the type annotation from the class annotations.
-                    model_type = obj.__class__.__annotations__.get(variable_name)
-
-                    # Fallback: inspect __init__ signature if not found in class-level annotations.
-                    if model_type is None:
-                        sig = inspect.signature(obj.__init__)
-                        if variable_name in sig.parameters:
-                            model_type = sig.parameters[variable_name].annotation
-
-                    if model_type is None:
-                        msg = f"No type annotation found for {variable_name}"
-                        logger.error(msg)
-                        continue
-
-                    # If the annotation is a Union, extract the pydantic model type.
-                    if get_origin(model_type) in (UnionType, Union):
-                        for arg in get_args(model_type):
-                            if isinstance(arg, type) and issubclass(arg, BaseModel):
-                                model_type = arg
-                                break
-                        else:
-                            msg = f"No pydantic type found in Union for {variable_name}"
-                            logger.error(msg)
-                            continue
-
-                    if not (isinstance(model_type, type) and issubclass(model_type, BaseModel)):
-                        msg = f"Type for {variable_name} is not a pydantic model: {model_type}"
-                        logger.error(msg)
-                        continue
-
-                    pydantic_models[variable_name] = model_type(**content)
-            else:
-                msg = f"Found unexpected child {child}"
-                logger.error(msg)
+            if child.is_file():
+                if child.name == Dumper.simple_attrs:
+                    try:
+                        with child.open(encoding="utf-8") as file:
+                            simple_attrs = json.load(file)
+                    except Exception:
+                        logger.exception("Error loading simple attributes from %s", child)
+                elif child.name == Dumper.arrays:
+                    try:
+                        arrays = dict(np.load(child, allow_pickle=False))
+                    except Exception as e:  # noqa: BLE001
+                        logger.warning("Could not load numpy arrays from %s: %s", child, e)
+
+            elif child.is_dir():
+                if child.name == Dumper.hf_models:
+                    for model_dir in child.iterdir():
+                        if model_dir.is_dir():
+                            attr_name = model_dir.name
+                            try:
+                                hf_models[attr_name] = AutoModelForSequenceClassification.from_pretrained(model_dir)
+                            except Exception:
+                                logger.exception("Error loading Hugging Face model '%s' from %s", attr_name, model_dir)
+                elif child.name == Dumper.hf_tokenizers:
+                    for tokenizer_dir in child.iterdir():
+                        if tokenizer_dir.is_dir():
+                            attr_name = tokenizer_dir.name
+                            try:
+                                hf_tokenizers[attr_name] = AutoTokenizer.from_pretrained(tokenizer_dir)
+                            except Exception:
+                                logger.exception(
+                                    "Error loading Hugging Face tokenizer '%s' from %s", attr_name, tokenizer_dir
+                                )
+                elif child.name == Dumper.pydantic_models:
+                    for model_file in child.iterdir():
+                        if model_file.is_file() and model_file.suffix == ".json":
+                            variable_name = model_file.stem
+                            try:
+                                with model_file.open("r", encoding="utf-8") as file:
+                                    content = json.load(file)
+
+                                model_type = obj.__class__.__annotations__.get(variable_name)
+
+                                if model_type is None:
+                                    sig = inspect.signature(obj.__init__)
+                                    if variable_name in sig.parameters:
+                                        model_type = sig.parameters[variable_name].annotation
+
+                                if model_type is None:
+                                    logger.error("No type annotation found for pydantic model %s", variable_name)
+                                    continue
+
+                                potential_types = []
+                                if get_origin(model_type) in (UnionType, Union):
+                                    potential_types.extend(get_args(model_type))
+                                else:
+                                    potential_types.append(model_type)
+
+                                pydantic_type = None
+                                for p_type in potential_types:
+                                    if inspect.isclass(p_type) and issubclass(p_type, BaseModel):
+                                        pydantic_type = p_type
+                                        break
+
+                                if pydantic_type is None:
+                                    logger.error("No pydantic type found in annotation for %s", variable_name)
+                                    continue
+
+                                pydantic_models[variable_name] = pydantic_type(**content)
+                            except Exception:
+                                logger.exception("Error loading pydantic model %s from %s", variable_name, model_file)
+
+                elif child.name == Dumper.tags:
+                    tags = {tags_dump.name: TagsList.load(tags_dump) for tags_dump in child.iterdir()}
+                elif child.name == Dumper.embedders:
+                    embedders = {
+                        embedder_dump.name: Embedder.load(embedder_dump, override_config=embedder_config)
+                        for embedder_dump in child.iterdir()
+                    }
+                elif child.name == Dumper.indexes:
+                    indexes = {index_dump.name: VectorIndex.load(index_dump) for index_dump in child.iterdir()}
+                elif child.name == Dumper.estimators:
+                    estimators = {}
+                    for estimator_dump in child.iterdir():
+                        if estimator_dump.is_file() and estimator_dump.suffix == ".joblib":
+                            try:
+                                estimators[estimator_dump.stem] = joblib.load(estimator_dump)
+                            except Exception:
+                                logger.exception(
+                                    "Error loading estimator %s from %s", estimator_dump.stem, estimator_dump
+                                )
+                elif child.name == Dumper.cross_encoders:
+                    cross_encoders = {
+                        cross_encoder_dump.name: Ranker.load(cross_encoder_dump, override_config=cross_encoder_config)
+                        for cross_encoder_dump in child.iterdir()
+                    }
+
         obj.__dict__.update(
-            tags | simple_attrs | arrays | embedders | indexes | estimators | cross_encoders | pydantic_models
+            tags
+            | simple_attrs
+            | arrays
+            | embedders
+            | indexes
+            | estimators
+            | cross_encoders
+            | pydantic_models
+            | hf_models
+            | hf_tokenizers
         )
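
Finally, a hedged round-trip sketch of the new behavior. The import path for Dumper, the way dump() collects attributes, and the optional override-config arguments of load() are assumptions based on this diff alone; prajjwal1/bert-tiny is just a small example checkpoint:

from pathlib import Path

from transformers import AutoModelForSequenceClassification, AutoTokenizer

from autointent._dump_tools import Dumper  # hypothetical import path


class DummyModule:
    def __init__(self) -> None:
        self.n_classes = 3  # simple attribute -> JSON
        self.tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
        self.model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")


dump_dir = Path("dump_demo")
Dumper.dump(DummyModule(), dump_dir)  # writes hf_models/model/, hf_tokenizers/tokenizer/, ...

restored = DummyModule.__new__(DummyModule)  # bare instance; load() repopulates __dict__
Dumper.load(restored, dump_dir)  # override configs omitted here, assumed optional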