Commit 5637fb8

added main code for saving models

1 parent 29de65d

2 files changed: +184 / -99 lines


autointent/_dump_tools.py

Lines changed: 165 additions & 80 deletions
@@ -1,6 +1,7 @@
 import inspect
 import json
 import logging
+import types
 from pathlib import Path
 from types import UnionType
 from typing import Any, TypeAlias, Union, get_args, get_origin
@@ -10,6 +11,13 @@
 import numpy.typing as npt
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)

 from autointent import Embedder, Ranker, VectorIndex
 from autointent.configs import CrossEncoderConfig, EmbedderConfig
@@ -18,7 +26,17 @@
 ModuleSimpleAttributes = None | str | int | float | bool | list  # type: ignore[type-arg]

 ModuleAttributes: TypeAlias = (
-    ModuleSimpleAttributes | TagsList | np.ndarray | Embedder | VectorIndex | BaseEstimator | Ranker  # type: ignore[type-arg]
+    ModuleSimpleAttributes
+    | TagsList
+    | np.ndarray
+    | Embedder
+    | VectorIndex
+    | BaseEstimator
+    | Ranker
+    | BaseModel
+    | PreTrainedModel
+    | PreTrainedTokenizer
+    | PreTrainedTokenizerFast
 )

 logger = logging.getLogger(__name__)
@@ -33,6 +51,8 @@ class Dumper:
     estimators = "estimators"
     cross_encoders = "cross_encoders"
     pydantic_models: str = "pydantic"
+    hf_models = "hf_models"
+    hf_tokenizers = "hf_tokenizers"

     @staticmethod
     def make_subdirectories(path: Path) -> None:
@@ -48,12 +68,14 @@ def make_subdirectories(path: Path) -> None:
             path / Dumper.estimators,
             path / Dumper.cross_encoders,
             path / Dumper.pydantic_models,
+            path / Dumper.hf_models,
+            path / Dumper.hf_tokenizers,
         ]
         for subdir in subdirectories:
             subdir.mkdir(parents=True, exist_ok=True)

     @staticmethod
-    def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901
+    def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901, PLR0912
         """Dump modules attributes to filestystem.

         Args:
@@ -67,7 +89,26 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901
         Dumper.make_subdirectories(path)

         for key, val in attrs.items():
-            if isinstance(val, TagsList):
+            if isinstance(val, PreTrainedModel):
+                try:
+                    model_path = path / Dumper.hf_models / key
+                    val.save_pretrained(model_path)
+                except Exception:
+                    logger.exception("Error dumping Hugging Face model %s", key)
+            elif isinstance(val, PreTrainedTokenizer | PreTrainedTokenizerFast):
+                try:
+                    tokenizer_path = path / Dumper.hf_tokenizers / key
+                    val.save_pretrained(tokenizer_path)
+                except Exception:
+                    logger.exception("Error dumping Hugging Face tokenizer %s", key)
+            elif isinstance(val, BaseModel):
+                try:
+                    pydantic_path = path / Dumper.pydantic_models / f"{key}.json"
+                    with pydantic_path.open("w", encoding="utf-8") as file:
+                        json.dump(val.model_dump(), file, ensure_ascii=False, indent=4)
+                except Exception:
+                    logger.exception("Error dumping pydantic model %s", key)
+            elif isinstance(val, TagsList):
                 val.dump(path / Dumper.tags / key)
             elif isinstance(val, ModuleSimpleAttributes):
                 simple_attrs[key] = val
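The new branches in dump() persist Hugging Face objects via save_pretrained and serialize any pydantic BaseModel to JSON. Below is a rough, self-contained sketch of the round-trip these branches rely on; the checkpoint name, directory layout, and ExampleConfig class are illustrative assumptions, not part of this commit:

from pathlib import Path

from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class ExampleConfig(BaseModel):  # hypothetical config, stands in for EmbedderConfig and friends
    model_name: str
    batch_size: int = 32


dump_root = Path("dump_example")
for sub in ("hf_models/_model", "hf_tokenizers/_tokenizer", "pydantic"):
    (dump_root / sub).mkdir(parents=True, exist_ok=True)

# Hugging Face models and tokenizers serialize to a directory of config and weight files.
checkpoint = "prajjwal1/bert-tiny"  # arbitrary small public checkpoint, only for illustration
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.save_pretrained(dump_root / "hf_models" / "_model")
tokenizer.save_pretrained(dump_root / "hf_tokenizers" / "_tokenizer")

# Pydantic models round-trip through plain JSON via model_dump() / model_dump_json().
config = ExampleConfig(model_name=checkpoint)
(dump_root / "pydantic" / "config.json").write_text(config.model_dump_json(indent=4), encoding="utf-8")

# Loading mirrors the dump: from_pretrained on the directories, JSON content back into the config class.
restored_model = AutoModelForSequenceClassification.from_pretrained(dump_root / "hf_models" / "_model")
restored_tokenizer = AutoTokenizer.from_pretrained(dump_root / "hf_tokenizers" / "_tokenizer")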
@@ -78,25 +119,23 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901
             elif isinstance(val, VectorIndex):
                 val.dump(path / Dumper.indexes / key)
             elif isinstance(val, BaseEstimator):
-                joblib.dump(val, path / Dumper.estimators / key)
+                try:
+                    joblib.dump(val, path / Dumper.estimators / f"{key}.joblib")
+                except Exception:
+                    logger.exception("Error dumping BaseEstimator %s", key)
             elif isinstance(val, Ranker):
                 val.save(str(path / Dumper.cross_encoders / key))
-            elif isinstance(val, CrossEncoderConfig | EmbedderConfig):
-                try:
-                    pydantic_path = path / Dumper.pydantic_models / f"{key}.json"
-                    with pydantic_path.open("w", encoding="utf-8") as file:
-                        json.dump(val.model_dump(), file, ensure_ascii=False, indent=4)
-                except Exception as e:
-                    msg = f"Error dumping pydantic model {key}: {e}"
-                    logging.exception(msg)
-            else:
-                msg = f"Attribute {key} of type {type(val)} cannot be dumped to file system."
-                logger.error(msg)
-
-        with (path / Dumper.simple_attrs).open("w") as file:
+            elif not isinstance(val, type | types.ModuleType | types.FunctionType | types.MethodType):
+                logger.warning("Attribute '%s' of type %s cannot be dumped and will be skipped.", key, type(val))
+
+        with (path / Dumper.simple_attrs).open("w", encoding="utf-8") as file:
             json.dump(simple_attrs, file, ensure_ascii=False, indent=4)

-        np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)
+        if arrays:
+            try:
+                np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)
+            except Exception:
+                logger.exception("Error saving numpy arrays to %s", path / Dumper.arrays)

     @staticmethod
     def load(  # noqa: PLR0912, C901, PLR0915
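Estimators now get an explicit .joblib suffix and the arrays archive is only written when there is something to save. A small sketch of the joblib / np.savez persistence pattern assumed here (paths and the toy estimator are illustrative):

from pathlib import Path

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression

root = Path("dump_example")
root.mkdir(exist_ok=True)

# One estimator per *.joblib file, so the loader can filter on the suffix.
clf = LogisticRegression().fit([[0.0], [1.0]], [0, 1])
joblib.dump(clf, root / "clf.joblib")
restored = joblib.load(root / "clf.joblib")

# Arrays are bundled into a single .npz archive, skipped entirely when empty.
arrays = {"centroids": np.zeros((2, 3))}
if arrays:
    np.savez(root / "arrays", allow_pickle=False, **arrays)  # np.savez appends the .npz suffix
loaded = dict(np.load(root / "arrays.npz", allow_pickle=False))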
@@ -114,69 +153,115 @@ def load( # noqa: PLR0912, C901, PLR0915
         estimators: dict[str, Any] = {}
         cross_encoders: dict[str, Any] = {}
         pydantic_models: dict[str, Any] = {}
+        hf_models: dict[str, Any] = {}
+        hf_tokenizers: dict[str, Any] = {}

         for child in path.iterdir():
-            if child.name == Dumper.tags:
-                tags = {tags_dump.name: TagsList.load(tags_dump) for tags_dump in child.iterdir()}
-            elif child.name == Dumper.simple_attrs:
-                with child.open() as file:
-                    simple_attrs = json.load(file)
-            elif child.name == Dumper.arrays:
-                arrays = dict(np.load(child))
-            elif child.name == Dumper.embedders:
-                embedders = {
-                    embedder_dump.name: Embedder.load(embedder_dump, override_config=embedder_config)
-                    for embedder_dump in child.iterdir()
-                }
-            elif child.name == Dumper.indexes:
-                indexes = {index_dump.name: VectorIndex.load(index_dump) for index_dump in child.iterdir()}
-            elif child.name == Dumper.estimators:
-                estimators = {estimator_dump.name: joblib.load(estimator_dump) for estimator_dump in child.iterdir()}
-            elif child.name == Dumper.cross_encoders:
-                cross_encoders = {
-                    cross_encoder_dump.name: Ranker.load(cross_encoder_dump, override_config=cross_encoder_config)
-                    for cross_encoder_dump in child.iterdir()
-                }
-            elif child.name == Dumper.pydantic_models:
-                for model_file in child.iterdir():
-                    with model_file.open("r", encoding="utf-8") as file:
-                        content = json.load(file)
-                    variable_name = model_file.stem
-
-                    # First try to get the type annotation from the class annotations.
-                    model_type = obj.__class__.__annotations__.get(variable_name)
-
-                    # Fallback: inspect __init__ signature if not found in class-level annotations.
-                    if model_type is None:
-                        sig = inspect.signature(obj.__init__)
-                        if variable_name in sig.parameters:
-                            model_type = sig.parameters[variable_name].annotation
-
-                    if model_type is None:
-                        msg = f"No type annotation found for {variable_name}"
-                        logger.error(msg)
-                        continue
-
-                    # If the annotation is a Union, extract the pydantic model type.
-                    if get_origin(model_type) in (UnionType, Union):
-                        for arg in get_args(model_type):
-                            if isinstance(arg, type) and issubclass(arg, BaseModel):
-                                model_type = arg
-                                break
-                        else:
-                            msg = f"No pydantic type found in Union for {variable_name}"
-                            logger.error(msg)
-                            continue
-
-                    if not (isinstance(model_type, type) and issubclass(model_type, BaseModel)):
-                        msg = f"Type for {variable_name} is not a pydantic model: {model_type}"
-                        logger.error(msg)
-                        continue
-
-                    pydantic_models[variable_name] = model_type(**content)
-            else:
-                msg = f"Found unexpected child {child}"
-                logger.error(msg)
+            if child.is_file():
+                if child.name == Dumper.simple_attrs:
+                    try:
+                        with child.open(encoding="utf-8") as file:
+                            simple_attrs = json.load(file)
+                    except Exception:
+                        logger.exception("Error loading simple attributes from %s", child)
+                elif child.name == Dumper.arrays:
+                    try:
+                        arrays = dict(np.load(child, allow_pickle=False))
+                    except Exception as e:  # noqa: BLE001
+                        logger.warning("Could not load numpy arrays from %s: %s", child, e)
+
+            elif child.is_dir():
+                if child.name == Dumper.hf_models:
+                    for model_dir in child.iterdir():
+                        if model_dir.is_dir():
+                            attr_name = model_dir.name
+                            try:
+                                hf_models[attr_name] = AutoModelForSequenceClassification.from_pretrained(model_dir)
+                            except Exception:
+                                logger.exception("Error loading Hugging Face model '%s' from %s", attr_name, model_dir)
+                elif child.name == Dumper.hf_tokenizers:
+                    for tokenizer_dir in child.iterdir():
+                        if tokenizer_dir.is_dir():
+                            attr_name = tokenizer_dir.name
+                            try:
+                                hf_tokenizers[attr_name] = AutoTokenizer.from_pretrained(tokenizer_dir)
+                            except Exception:
+                                logger.exception(
+                                    "Error loading Hugging Face tokenizer '%s' from %s", attr_name, tokenizer_dir
+                                )
+                elif child.name == Dumper.pydantic_models:
+                    for model_file in child.iterdir():
+                        if model_file.is_file() and model_file.suffix == ".json":
+                            variable_name = model_file.stem
+                            try:
+                                with model_file.open("r", encoding="utf-8") as file:
+                                    content = json.load(file)
+
+                                model_type = obj.__class__.__annotations__.get(variable_name)
+
+                                if model_type is None:
+                                    sig = inspect.signature(obj.__init__)
+                                    if variable_name in sig.parameters:
+                                        model_type = sig.parameters[variable_name].annotation
+
+                                if model_type is None:
+                                    logger.error("No type annotation found for pydantic model %s", variable_name)
+                                    continue
+
+                                potential_types = []
+                                if get_origin(model_type) in (UnionType, Union):
+                                    potential_types.extend(get_args(model_type))
+                                else:
+                                    potential_types.append(model_type)
+
+                                pydantic_type = None
+                                for p_type in potential_types:
+                                    if inspect.isclass(p_type) and issubclass(p_type, BaseModel):
+                                        pydantic_type = p_type
+                                        break
+
+                                if pydantic_type is None:
+                                    logger.error("No pydantic type found in annotation for %s", variable_name)
+                                    continue
+
+                                pydantic_models[variable_name] = pydantic_type(**content)
+                            except Exception:
+                                logger.exception("Error loading pydantic model %s from %s", variable_name, model_file)
+
+                elif child.name == Dumper.tags:
+                    tags = {tags_dump.name: TagsList.load(tags_dump) for tags_dump in child.iterdir()}
+                elif child.name == Dumper.embedders:
+                    embedders = {
+                        embedder_dump.name: Embedder.load(embedder_dump, override_config=embedder_config)
+                        for embedder_dump in child.iterdir()
+                    }
+                elif child.name == Dumper.indexes:
+                    indexes = {index_dump.name: VectorIndex.load(index_dump) for index_dump in child.iterdir()}
+                elif child.name == Dumper.estimators:
+                    estimators = {}
+                    for estimator_dump in child.iterdir():
+                        if estimator_dump.is_file() and estimator_dump.suffix == ".joblib":
+                            try:
+                                estimators[estimator_dump.stem] = joblib.load(estimator_dump)
+                            except Exception:
+                                logger.exception(
+                                    "Error loading estimator %s from %s", estimator_dump.stem, estimator_dump
+                                )
+                elif child.name == Dumper.cross_encoders:
+                    cross_encoders = {
+                        cross_encoder_dump.name: Ranker.load(cross_encoder_dump, override_config=cross_encoder_config)
+                        for cross_encoder_dump in child.iterdir()
+                    }
+
         obj.__dict__.update(
-            tags | simple_attrs | arrays | embedders | indexes | estimators | cross_encoders | pydantic_models
+            tags
+            | simple_attrs
+            | arrays
+            | embedders
+            | indexes
+            | estimators
+            | cross_encoders
+            | pydantic_models
+            | hf_models
+            | hf_tokenizers
         )
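The rewritten pydantic branch of load() recovers the target class from the attribute's type annotation, unwrapping Optional/Union annotations before instantiating. A minimal standalone sketch of that resolution step, with a hypothetical Holder class and ExampleConfig standing in for the real module and config types:

import inspect
from types import UnionType
from typing import Union, get_args, get_origin

from pydantic import BaseModel


class ExampleConfig(BaseModel):  # hypothetical, stands in for EmbedderConfig / CrossEncoderConfig
    model_name: str


class Holder:  # hypothetical module-like class with an annotated attribute
    config: ExampleConfig | None = None


annotation = Holder.__annotations__["config"]  # ExampleConfig | None

# Unwrap Optional/Union annotations, then pick the first pydantic class among the candidates.
if get_origin(annotation) in (UnionType, Union):
    candidates = list(get_args(annotation))
else:
    candidates = [annotation]
pydantic_type = next((t for t in candidates if inspect.isclass(t) and issubclass(t, BaseModel)), None)

assert pydantic_type is ExampleConfig
restored = pydantic_type(**{"model_name": "some-model"})  # content as it would come from the dumped JSON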

autointent/modules/scoring/_bert.py

Lines changed: 19 additions & 19 deletions
@@ -7,6 +7,7 @@
 import numpy.typing as npt
 import torch
 from datasets import Dataset
+from sklearn.preprocessing import LabelEncoder
 from transformers import (
     AutoModelForSequenceClassification,
     AutoTokenizer,
@@ -79,6 +80,10 @@ def fit(
     ) -> None:
         if hasattr(self, "_model"):
             self.clear_cache()
+        if not isinstance(labels[0], list) and isinstance(labels[0], str):
+            self._label_encoder = LabelEncoder()
+            encoded_labels = self._label_encoder.fit_transform(labels)
+            labels = encoded_labels.tolist()
         self._validate_task(labels)

         model_name = self.model_config.model_name
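The added guard in fit() maps plain string labels to integer ids before validation and training. Roughly what sklearn's LabelEncoder does here (the example labels are made up):

from sklearn.preprocessing import LabelEncoder

labels = ["greeting", "refund", "greeting", "goodbye"]  # made-up single-label string targets

if not isinstance(labels[0], list) and isinstance(labels[0], str):
    encoder = LabelEncoder()
    labels = encoder.fit_transform(labels).tolist()

print(labels)                  # [1, 2, 1, 0] -- classes are sorted alphabetically
print(list(encoder.classes_))  # ['goodbye', 'greeting', 'refund']
# encoder.inverse_transform([1]) maps back to 'greeting' when decoding predictions later.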
@@ -88,30 +93,20 @@ def fit(
         id2label = {i: i for i in range(self._n_classes)}

         self._model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
-            num_labels=self._n_classes,
-            label2id=label2id,
-            id2label=id2label,
-            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
+            model_name, num_labels=self._n_classes, label2id=label2id, id2label=id2label
         )

         use_cpu = self.model_config.device == "cpu"

-        def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
-            return self._tokenizer(  # type: ignore[no-any-return]
-                examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
-            )
-
         dataset = Dataset.from_dict({"text": utterances, "labels": labels})

-        if self._multilabel:
-            # hugging face uses F.binary_cross_entropy_with_logits under the hood
-            # which requires target labels to be of float type
-            dataset = dataset.map(
-                lambda example: {"label": torch.tensor(example["labels"], dtype=torch.float)}, remove_columns="labels"
-            )
+        def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
+            tokenizer_options = self.model_config.tokenizer_config.model_dump()
+            tokenizer_options.pop("padding", None)
+            tokenizer_options.pop("truncation", None)
+            return self._tokenizer(examples["text"], truncation=True, padding=False, **tokenizer_options)

-        tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=self.batch_size)
+        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = TrainingArguments(
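Tokenization now truncates but no longer pads, leaving padding to the collator at batch time. A sketch of that dynamic-padding pattern (checkpoint and texts are illustrative; DataCollatorWithPadding pads each batch to its own longest sequence):

from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")  # arbitrary small checkpoint

texts = ["hi", "please cancel my subscription and refund the last charge"]
features = [tokenizer(t, truncation=True, padding=False) for t in texts]  # ragged lengths at this point

collator = DataCollatorWithPadding(tokenizer=tokenizer)
batch = collator(features)          # pads only up to the longest sequence in this particular batch
print(batch["input_ids"].shape)     # e.g. torch.Size([2, 12])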
@@ -127,12 +122,14 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
                 use_cpu=use_cpu,
             )

+            data_collator = DataCollatorWithPadding(tokenizer=self._tokenizer)
+
             trainer = Trainer(
                 model=self._model,
                 args=training_args,
                 train_dataset=tokenized_dataset,
                 tokenizer=self._tokenizer,
-                data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer),
+                data_collator=data_collator,
             )

             trainer.train()
@@ -146,9 +143,12 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:

         device = next(self._model.parameters()).device
         all_predictions = []
+        tokenizer_options = self.model_config.tokenizer_config.model_dump()
+        tokenizer_options.pop("padding", None)
+        tokenizer_options.pop("truncation", None)
         for i in range(0, len(utterances), self.batch_size):
             batch = utterances[i : i + self.batch_size]
-            inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            inputs = self._tokenizer(batch, return_tensors="pt", padding=True, truncation=True, **tokenizer_options)
             inputs = {k: v.to(device) for k, v in inputs.items()}
             with torch.no_grad():
                 outputs = self._model(**inputs)
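At predict time, padding and truncation are forced explicitly, so any conflicting keys from the stored tokenizer config are dropped first. A self-contained sketch of this batched, no-grad scoring loop under assumed names (checkpoint, texts, and batch size are illustrative):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

checkpoint = "prajjwal1/bert-tiny"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

tokenizer_options = {"max_length": 64}  # stands in for tokenizer_config.model_dump() minus padding/truncation
utterances = ["hello there", "what is my account balance", "cancel my card"]
batch_size = 2

device = next(model.parameters()).device
all_scores = []
for i in range(0, len(utterances), batch_size):
    batch = utterances[i : i + batch_size]
    inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, **tokenizer_options)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():  # pure inference, no gradients needed
        logits = model(**inputs).logits
    all_scores.append(torch.softmax(logits, dim=-1).cpu())
scores = torch.cat(all_scores)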
