Commit 2c6ace8

Authored by SeBorgey, github-actions[bot], voorhs, Darinochka, and Samoed
dumper saving (#180)
* added main code for saving models
* Update optimizer_config.schema.json
* checker fixes
* Revert "checker fixes" (this reverts commit 6e32eb9)
* Revert "added main code for saving models" (this reverts commit 5637fb8)
* drat main code for new dumper
* ruf fix
* comments
* added code for test dumper
* Check dumper (#182)
  * Feat/code carbon each node (#175)
    * feat: update codecarbon
    * feat: update codecarbon
    * feat: added codecarbon
    * Update optimizer_config.schema.json
    * fix: fixed import mypy
    * fix: codecarbon package
    * fix: only float\integer log
    * fix: codecarbon package
    * fix: mypy
    * fix: test
    * fix: delete emissions
    * fix: test
    * Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  * standartize pyproject & speedup tests (#176)
    * speedup tests
    * fix pyproject
    * Update optimizer_config.schema.json
    * move optional dependencies
    * fixes
    * add xdist
    * fix ci
    * download data from hub in doc
    * add caching
    * add doc cache
    * Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
    * Co-authored-by: voorhs <[email protected]>
  * add proper `omit` definition for tests coverage report (#179)
    * add proper `omit` definition
    * Update optimizer_config.schema.json
    * exclude tmp from coverage report
    * Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  * add node validators (#177)
    * add node validators
    * add comments
    * Update optimizer_config.schema.json
    * rename bert model
    * lint
    * fixes
    * fix test
    * Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
    * Co-authored-by: voorhs <[email protected]>
  * update makefile
  * update bert test
  * mypy workaround
  * attempt to fix windows permission error
  * workaround
  * Co-authored-by: Darinochka <[email protected]>
  * Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  * Co-authored-by: Roman Solomatin <[email protected]>
* Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* Co-authored-by: Алексеев Илья <[email protected]>
* Co-authored-by: Darinochka <[email protected]>
* Co-authored-by: Roman Solomatin <[email protected]>
1 parent: 86384cf

File tree (5 files changed, +123 −10 lines):

- Makefile
- autointent/_dump_tools.py
- autointent/context/data_handler/_stratification.py
- autointent/modules/scoring/_bert.py
- tests/modules/scoring/test_bert.py

Makefile
Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@ poetry = poetry run
 
 .PHONY: install
 install:
-	poetry install --with dev,test,typing,docs
+	poetry install --extras "dev test typing docs"
 
 .PHONY: test
 test:
@@ -24,7 +24,7 @@ lint:
 
 .PHONY: sync
 sync:
-	poetry sync --with dev,test,typing,docs
+	poetry sync --extras "dev test typing docs"
 
 .PHONY: docs
 docs:

autointent/_dump_tools.py
Lines changed: 66 additions & 3 deletions

@@ -33,6 +33,8 @@ class Dumper:
     estimators = "estimators"
     cross_encoders = "cross_encoders"
     pydantic_models: str = "pydantic"
+    hf_models = "hf_models"
+    hf_tokenizers = "hf_tokenizers"
 
     @staticmethod
     def make_subdirectories(path: Path) -> None:
@@ -48,12 +50,14 @@ def make_subdirectories(path: Path) -> None:
             path / Dumper.estimators,
             path / Dumper.cross_encoders,
             path / Dumper.pydantic_models,
+            path / Dumper.hf_models,
+            path / Dumper.hf_tokenizers,
         ]
         for subdir in subdirectories:
             subdir.mkdir(parents=True, exist_ok=True)
 
     @staticmethod
-    def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901
+    def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901, PLR0912, PLR0915
         """Dump modules attributes to filestystem.
 
         Args:
@@ -89,6 +93,28 @@ def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901
                 except Exception as e:
                     msg = f"Error dumping pydantic model {key}: {e}"
                     logging.exception(msg)
+            elif (key == "_model" or "model" in key.lower()) and hasattr(val, "save_pretrained"):
+                model_path = path / Dumper.hf_models / key
+                model_path.mkdir(parents=True, exist_ok=True)
+                try:
+                    val.save_pretrained(model_path)
+                    class_info = {"module": val.__class__.__module__, "name": val.__class__.__name__}
+                    with (model_path / "class_info.json").open("w") as f:
+                        json.dump(class_info, f)
+                except Exception as e:
+                    msg = f"Error dumping HF model {key}: {e}"
+                    logger.exception(msg)
+            elif (key == "_tokenizer" or "tokenizer" in key.lower()) and hasattr(val, "save_pretrained"):
+                tokenizer_path = path / Dumper.hf_tokenizers / key
+                tokenizer_path.mkdir(parents=True, exist_ok=True)
+                try:
+                    val.save_pretrained(tokenizer_path)
+                    class_info = {"module": val.__class__.__module__, "name": val.__class__.__name__}
+                    with (tokenizer_path / "class_info.json").open("w") as f:
+                        json.dump(class_info, f)
+                except Exception as e:
+                    msg = f"Error dumping HF tokenizer {key}: {e}"
+                    logger.exception(msg)
             else:
                 msg = f"Attribute {key} of type {type(val)} cannot be dumped to file system."
                 logger.error(msg)
@@ -114,6 +140,8 @@ def load(  # noqa: PLR0912, C901, PLR0915
         estimators: dict[str, Any] = {}
         cross_encoders: dict[str, Any] = {}
         pydantic_models: dict[str, Any] = {}
+        hf_models: dict[str, Any] = {}
+        hf_tokenizers: dict[str, Any] = {}
 
         for child in path.iterdir():
             if child.name == Dumper.tags:
@@ -151,7 +179,6 @@ def load(  # noqa: PLR0912, C901, PLR0915
                     sig = inspect.signature(obj.__init__)
                     if variable_name in sig.parameters:
                         model_type = sig.parameters[variable_name].annotation
-
                         if model_type is None:
                             msg = f"No type annotation found for {variable_name}"
                             logger.error(msg)
@@ -174,9 +201,45 @@ def load(  # noqa: PLR0912, C901, PLR0915
                             continue
 
                         pydantic_models[variable_name] = model_type(**content)
+            elif child.name == Dumper.hf_models:
+                for model_dir in child.iterdir():
+                    try:
+                        with (model_dir / "class_info.json").open("r") as f:
+                            class_info = json.load(f)
+
+                        module = __import__(class_info["module"], fromlist=[class_info["name"]])
+                        model_class = getattr(module, class_info["name"])
+
+                        hf_models[model_dir.name] = model_class.from_pretrained(model_dir)
+                    except Exception as e:  # noqa: PERF203
+                        msg = f"Error loading HF model {model_dir.name}: {e}"
+                        logger.exception(msg)
+            elif child.name == Dumper.hf_tokenizers:
+                for tokenizer_dir in child.iterdir():
+                    try:
+                        with (tokenizer_dir / "class_info.json").open("r") as f:
+                            class_info = json.load(f)
+
+                        module = __import__(class_info["module"], fromlist=[class_info["name"]])
+                        tokenizer_class = getattr(module, class_info["name"])
+
+                        hf_tokenizers[tokenizer_dir.name] = tokenizer_class.from_pretrained(tokenizer_dir)
+                    except Exception as e:  # noqa: PERF203
+                        msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}"
+                        logger.exception(msg)
             else:
                 msg = f"Found unexpected child {child}"
                 logger.error(msg)
+
         obj.__dict__.update(
-            tags | simple_attrs | arrays | embedders | indexes | estimators | cross_encoders | pydantic_models
+            tags
+            | simple_attrs
+            | arrays
+            | embedders
+            | indexes
+            | estimators
+            | cross_encoders
+            | pydantic_models
+            | hf_models
+            | hf_tokenizers
         )
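
In short, the diff makes Dumper treat any attribute whose name suggests a model or tokenizer and that exposes save_pretrained as a Hugging Face object: it is written under hf_models/ or hf_tokenizers/ together with a class_info.json recording the concrete class, which load later re-imports and hands to from_pretrained. Below is a minimal sketch of driving that path directly; the TinyModule stand-in class and the dump_example directory are purely illustrative, and the import path for Dumper is assumed from the file location shown above.

from pathlib import Path

from transformers import AutoModelForSequenceClassification, AutoTokenizer

from autointent._dump_tools import Dumper  # import path assumed from this diff's file location


class TinyModule:
    """Hypothetical stand-in for an autointent module holding HF objects as attributes."""

    def __init__(self) -> None:
        self._model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=2)
        self._tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")


dump_dir = Path("dump_example")
Dumper.make_subdirectories(dump_dir)
Dumper.dump(TinyModule(), dump_dir)

# Expected layout per the diff:
#   dump_example/hf_models/_model/          weights from save_pretrained plus class_info.json
#   dump_example/hf_tokenizers/_tokenizer/  tokenizer files plus class_info.json
# On load, class_info.json supplies {"module": ..., "name": ...}; the class is re-imported
# with __import__, resolved with getattr, and restored via from_pretrained on the directory.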

autointent/context/data_handler/_stratification.py
Lines changed: 3 additions & 2 deletions

@@ -12,7 +12,7 @@
 from numpy import typing as npt
 from sklearn.model_selection import train_test_split
 from skmultilearn.model_selection import IterativeStratification
-from transformers import set_seed
+from transformers import set_seed  # type: ignore[attr-defined]
 
 from autointent import Dataset
 from autointent.custom_types import LabelType
@@ -128,7 +128,8 @@ def _split_multilabel(self, dataset: HFDataset, test_size: float) -> Sequence[np
         Returns:
             A sequence containing indices for train and test splits.
         """
-        set_seed(self.random_seed)  # workaround for buggy nature of IterativeStratification from skmultilearn
+        if self.random_seed is not None:
+            set_seed(self.random_seed)  # workaround for buggy nature of IterativeStratification from skmultilearn
         splitter = IterativeStratification(
             n_splits=2,
             order=2,
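
For context, the added guard matters because the seed is optional: transformers.set_seed expects an integer, so calling it with a None seed would fail rather than mean "leave randomness unfixed". A minimal sketch of the guarded pattern in isolation (the seed value here is illustrative):

from typing import Optional

from transformers import set_seed

random_seed: Optional[int] = 42  # None means randomness is left unfixed

if random_seed is not None:
    # IterativeStratification from skmultilearn exposes no seed argument of its own,
    # so global seeding before the split is used as the workaround noted above.
    set_seed(random_seed)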

autointent/modules/scoring/_bert.py
Lines changed: 3 additions & 3 deletions

@@ -7,7 +7,7 @@
 import numpy.typing as npt
 import torch
 from datasets import Dataset
-from transformers import (
+from transformers import (  # type: ignore[attr-defined]
     AutoModelForSequenceClassification,
     AutoTokenizer,
     DataCollatorWithPadding,
@@ -127,15 +127,15 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             use_cpu=use_cpu,
         )
 
-        trainer = Trainer(
+        trainer = Trainer(  # type: ignore[no-untyped-call]
             model=self._model,
             args=training_args,
             train_dataset=tokenized_dataset,
             tokenizer=self._tokenizer,
             data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer),
         )
 
-        trainer.train()
+        trainer.train()  # type: ignore[attr-defined]
 
         self._model.eval()

tests/modules/scoring/test_bert.py
Lines changed: 49 additions & 0 deletions

@@ -1,10 +1,59 @@
+import shutil
+import tempfile
+from pathlib import Path
+
 import numpy as np
 import pytest
 
 from autointent.context.data_handler import DataHandler
 from autointent.modules import BertScorer
 
 
+def test_bert_scorer_dump_load(dataset):
+    """Test that BertScorer can be saved and loaded while preserving predictions."""
+    data_handler = DataHandler(dataset)
+
+    # Create and train scorer
+    scorer_original = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+    scorer_original.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
+
+    # Test data
+    test_data = [
+        "why is there a hold on my account",
+        "why is my bank account frozen",
+    ]
+
+    # Get predictions before saving
+    predictions_before = scorer_original.predict(test_data)
+
+    # Create temp directory and save model
+    temp_dir_path = Path(tempfile.mkdtemp(prefix="bert_scorer_test_"))
+    try:
+        # Save the model
+        scorer_original.dump(str(temp_dir_path))
+
+        # Create a new scorer and load saved model
+        scorer_loaded = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+        scorer_loaded.load(str(temp_dir_path))
+
+        # Verify model and tokenizer are loaded
+        assert hasattr(scorer_loaded, "_model")
+        assert scorer_loaded._model is not None
+        assert hasattr(scorer_loaded, "_tokenizer")
+        assert scorer_loaded._tokenizer is not None
+
+        # Get predictions after loading
+        predictions_after = scorer_loaded.predict(test_data)
+
+        # Verify predictions match
+        assert predictions_before.shape == predictions_after.shape
+        np.testing.assert_allclose(predictions_before, predictions_after, atol=1e-6)
+
+    finally:
+        # Clean up
+        shutil.rmtree(temp_dir_path, ignore_errors=True)  # workaround for windows permission error
+
+
 def test_bert_prediction(dataset):
     """Test that the transformer model can fit and make predictions."""
     data_handler = DataHandler(dataset)
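
A note on the cleanup choice in the new test: the commit message mentions an attempt to fix a Windows permission error, and the test therefore uses tempfile.mkdtemp together with shutil.rmtree(..., ignore_errors=True) instead of relying on automatic temporary-directory cleanup. If the same pattern were needed in other tests, it could be factored into a fixture; the sketch below is hypothetical and not part of this diff.

import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path

import pytest


@pytest.fixture
def lenient_tmp_dir() -> Iterator[Path]:
    """Temporary directory whose cleanup tolerates files still held open on Windows."""
    path = Path(tempfile.mkdtemp(prefix="bert_scorer_test_"))
    try:
        yield path
    finally:
        # ignore_errors=True mirrors the workaround used in test_bert_scorer_dump_load above.
        shutil.rmtree(path, ignore_errors=True)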
