Commit 1014e83

Fix/bert models loading for inference (#205)

* refactor `load_search_space`
* first attempt to fix the issue
* change transformers resolving logic
* define `get_implicit_initialization_params` in child classes
* fix dump tools
* fix typing errors
* small adjustments

1 parent 5a9a192 · commit 1014e83

File tree

21 files changed: +165 −146 lines

autointent/_dump_tools.py

Lines changed: 43 additions & 24 deletions

```diff
@@ -7,8 +7,16 @@
 import joblib
 import numpy as np
 import numpy.typing as npt
+from peft import PeftModel
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
+from transformers import (  # type: ignore[attr-defined]
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)

 from autointent import Embedder, Ranker, VectorIndex
 from autointent.configs import CrossEncoderConfig, EmbedderConfig
@@ -34,6 +42,7 @@ class Dumper:
     pydantic_models: str = "pydantic"
     hf_models = "hf_models"
     hf_tokenizers = "hf_tokenizers"
+    ptuning_models = "ptuning_models"

     @staticmethod
     def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
@@ -52,6 +61,7 @@ def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
             path / Dumper.pydantic_models,
             path / Dumper.hf_models,
             path / Dumper.hf_tokenizers,
+            path / Dumper.ptuning_models,
         ]
         for subdir in subdirectories:
             subdir.mkdir(parents=True, exist_ok=exists_ok)
@@ -101,25 +111,38 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
             except Exception as e:
                 msg = f"Error dumping pydantic model {key}: {e}"
                 logging.exception(msg)
-            elif (key == "_model" or "model" in key.lower()) and hasattr(val, "save_pretrained"):
+            elif isinstance(val, PeftModel):
+                # dumping peft models is a nightmare...
+                # this might break with new versions of peft
+                try:
+                    if val._is_prompt_learning:  # noqa: SLF001
+                        # strategy to save prompt learning models: save prompt encoder and bert classifier separately
+                        model_path = path / Dumper.ptuning_models / key
+                        model_path.mkdir(parents=True, exist_ok=True)
+                        val.save_pretrained(str(model_path / "peft"))
+                        val.base_model.save_pretrained(model_path / "base_model")  # type: ignore[attr-defined]
+                    else:
+                        # strategy to save lora models: merge adapters and save as usual hugging face model
+                        model_path = path / Dumper.hf_models / key
+                        model_path.mkdir(parents=True, exist_ok=True)
+                        merged_model: PreTrainedModel = val.merge_and_unload()
+                        merged_model.save_pretrained(model_path)  # type: ignore[attr-defined]
+                except Exception as e:
+                    msg = f"Error dumping PeftModel {key}: {e}"
+                    logger.exception(msg)
+            elif isinstance(val, PreTrainedModel):
                 model_path = path / Dumper.hf_models / key
                 model_path.mkdir(parents=True, exist_ok=True)
                 try:
-                    val.save_pretrained(model_path)
-                    class_info = {"module": val.__class__.__module__, "name": val.__class__.__name__}
-                    with (model_path / "class_info.json").open("w") as f:
-                        json.dump(class_info, f)
+                    val.save_pretrained(model_path)  # type: ignore[attr-defined]
                 except Exception as e:
                     msg = f"Error dumping HF model {key}: {e}"
                     logger.exception(msg)
-            elif (key == "_tokenizer" or "tokenizer" in key.lower()) and hasattr(val, "save_pretrained"):
+            elif isinstance(val, PreTrainedTokenizer | PreTrainedTokenizerFast):
                 tokenizer_path = path / Dumper.hf_tokenizers / key
                 tokenizer_path.mkdir(parents=True, exist_ok=True)
                 try:
-                    val.save_pretrained(tokenizer_path)
-                    class_info = {"module": val.__class__.__module__, "name": val.__class__.__name__}
-                    with (tokenizer_path / "class_info.json").open("w") as f:
-                        json.dump(class_info, f)
+                    val.save_pretrained(tokenizer_path)  # type: ignore[union-attr]
                 except Exception as e:
                     msg = f"Error dumping HF tokenizer {key}: {e}"
                     logger.exception(msg)
@@ -202,29 +225,25 @@ def load(  # noqa: C901, PLR0912, PLR0915
                     msg = f"Error loading Pydantic model from {model_dir}: {e}"
                     logger.exception(msg)
                     continue
+            elif child.name == Dumper.ptuning_models:
+                for model_dir in child.iterdir():
+                    try:
+                        model = AutoModelForSequenceClassification.from_pretrained(model_dir / "base_model")
+                        hf_models[model_dir.name] = PeftModel.from_pretrained(model, model_dir / "peft")
+                    except Exception as e:  # noqa: PERF203
+                        msg = f"Error loading PeftModel {model_dir.name}: {e}"
+                        logger.exception(msg)
             elif child.name == Dumper.hf_models:
                 for model_dir in child.iterdir():
                     try:
-                        with (model_dir / "class_info.json").open("r") as f:
-                            class_info = json.load(f)
-
-                        module = __import__(class_info["module"], fromlist=[class_info["name"]])
-                        model_class = getattr(module, class_info["name"])
-
-                        hf_models[model_dir.name] = model_class.from_pretrained(model_dir)
+                        hf_models[model_dir.name] = AutoModelForSequenceClassification.from_pretrained(model_dir)
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading HF model {model_dir.name}: {e}"
                         logger.exception(msg)
             elif child.name == Dumper.hf_tokenizers:
                 for tokenizer_dir in child.iterdir():
                     try:
-                        with (tokenizer_dir / "class_info.json").open("r") as f:
-                            class_info = json.load(f)
-
-                        module = __import__(class_info["module"], fromlist=[class_info["name"]])
-                        tokenizer_class = getattr(module, class_info["name"])
-
-                        hf_tokenizers[tokenizer_dir.name] = tokenizer_class.from_pretrained(tokenizer_dir)
+                        hf_tokenizers[tokenizer_dir.name] = AutoTokenizer.from_pretrained(tokenizer_dir)
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}"
                         logger.exception(msg)
```
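The two PEFT branches above correspond to two different persistence strategies, and the `load` branches mirror them. The sketch below illustrates the roundtrip outside of `Dumper`, assuming a sequence-classification base model; the model name and dump paths are illustrative placeholders, not values from this repository.

```python
from peft import LoraConfig, PeftModel, PromptTuningConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification

base = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=3)

# LoRA adapters can be merged into the base weights, so the result
# round-trips as an ordinary Hugging Face model under hf_models/.
lora = get_peft_model(base, LoraConfig(task_type=TaskType.SEQ_CLS))
lora.merge_and_unload().save_pretrained("dump/hf_models/scorer")
restored = AutoModelForSequenceClassification.from_pretrained("dump/hf_models/scorer")

# Prompt-learning models have nothing to merge, so the prompt encoder and the
# base classifier are saved side by side under ptuning_models/ and re-attached on load.
base2 = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=3)
ptuned = get_peft_model(base2, PromptTuningConfig(task_type=TaskType.SEQ_CLS, num_virtual_tokens=8))
ptuned.save_pretrained("dump/ptuning_models/scorer/peft")
ptuned.base_model.save_pretrained("dump/ptuning_models/scorer/base_model")
reloaded = PeftModel.from_pretrained(
    AutoModelForSequenceClassification.from_pretrained("dump/ptuning_models/scorer/base_model"),
    "dump/ptuning_models/scorer/peft",
)
```

Note the asymmetry that `load` relies on: merged LoRA models come back through the plain `hf_models` path, while prompt-learning models need `PeftModel.from_pretrained` on top of the freshly loaded base model.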

autointent/context/_context.py

Lines changed: 29 additions & 18 deletions

```diff
@@ -7,7 +7,7 @@

 from autointent import Dataset
 from autointent._callbacks import CallbackHandler, get_callbacks
-from autointent.configs import CrossEncoderConfig, DataConfig, EmbedderConfig, LoggingConfig
+from autointent.configs import CrossEncoderConfig, DataConfig, EmbedderConfig, HFModelConfig, LoggingConfig

 from .data_handler import DataHandler
 from .optimization_info import OptimizationInfo
@@ -49,7 +49,7 @@ def configure_logging(self, config: LoggingConfig) -> None:
         self.callback_handler = get_callbacks(config.report_to)
         self.optimization_info = OptimizationInfo()

-    def configure_transformer(self, config: EmbedderConfig | CrossEncoderConfig) -> None:
+    def configure_transformer(self, config: EmbedderConfig | CrossEncoderConfig | HFModelConfig) -> None:
         """Configure the vector index client and embedder.

         Args:
@@ -59,6 +59,8 @@ def configure_transformer(self, config: EmbedderConfig | CrossEncoderConfig) ->
             self.embedder_config = config
         elif isinstance(config, CrossEncoderConfig):
             self.cross_encoder_config = config
+        elif isinstance(config, HFModelConfig):
+            self.transformer_config = config

     def set_dataset(self, dataset: Dataset, config: DataConfig) -> None:
         """Set the datasets for training, validation and testing.
@@ -133,31 +135,40 @@ def has_saved_modules(self) -> bool:
     def resolve_embedder(self) -> EmbedderConfig:
         """Resolve the embedder configuration.

-        Returns the best embedder configuration or default configuration.
-
-        Raises:
-            RuntimeError: If embedder configuration cannot be resolved.
+        This method returns the configuration with the following priorities:
+        - the best embedder configuration obtained during embedding node optimization
+        - default configuration preset by user with :py:meth:`Context.configure_transformer`
+        - default configuration preset by AutoIntent in :py:class:`autointent.configs.EmbedderConfig`
         """
         try:
             return self.optimization_info.get_best_embedder()
-        except ValueError as e:
+        except ValueError:
             if hasattr(self, "embedder_config"):
                 return self.embedder_config
-            msg = (
-                "Embedder could't be resolved. Either include embedding node into the "
-                "search space or set default config with Context.configure_transformer."
-            )
-            raise RuntimeError(msg) from e
+            return EmbedderConfig()

     def resolve_ranker(self) -> CrossEncoderConfig:
         """Resolve the cross-encoder configuration.

-        Returns default config if set.
-
-        Raises:
-            RuntimeError: If cross-encoder configuration cannot be resolved.
+        This method returns the configuration with the following priorities:
+        - default configuration preset by user with :py:meth:`Context.configure_transformer`
+        - default configuration preset by AutoIntent in :py:class:`autointent.configs.CrossEncoderConfig`
         """
         if hasattr(self, "cross_encoder_config"):
             return self.cross_encoder_config
-        msg = "Cross-encoder could't be resolved. Set default config with Context.configure_transformer."
-        raise RuntimeError(msg)
+        return CrossEncoderConfig()
+
+    def resolve_transformer(self) -> HFModelConfig:
+        """Resolve the transformer configuration.
+
+        This method returns the configuration with the following priorities:
+        - the best transformer configuration obtained during embedding node optimization
+        - default configuration preset by user with :py:meth:`Context.configure_transformer`
+        - default configuration preset by AutoIntent in :py:class:`autointent.configs.HFModelConfig`
+        """
+        try:
+            return self.optimization_info.get_best_embedder()
+        except ValueError:
+            if hasattr(self, "transformer_config"):
+                return self.transformer_config
+            return HFModelConfig()
```
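The new `resolve_*` methods replace hard `RuntimeError` failures with a three-tier fallback. A minimal self-contained sketch of that order, using stand-in classes so it runs without autointent installed (the names mirror the diff but are simplified):

```python
from dataclasses import dataclass


@dataclass
class HFModelConfig:  # stand-in for autointent.configs.HFModelConfig
    model_name: str = "bert-base-uncased"


class OptimizationInfo:  # stand-in: raises until an embedding node has been optimized
    def get_best_embedder(self) -> HFModelConfig:
        raise ValueError("no embedding node was optimized")


class Context:  # stand-in mirroring the resolution order above
    def __init__(self) -> None:
        self.optimization_info = OptimizationInfo()

    def configure_transformer(self, config: HFModelConfig) -> None:
        self.transformer_config = config

    def resolve_transformer(self) -> HFModelConfig:
        try:
            return self.optimization_info.get_best_embedder()  # 1. best result from optimization
        except ValueError:
            if hasattr(self, "transformer_config"):
                return self.transformer_config  # 2. user preset
            return HFModelConfig()  # 3. library default


context = Context()
assert context.resolve_transformer().model_name == "bert-base-uncased"  # falls through to default
context.configure_transformer(HFModelConfig(model_name="distilbert-base-uncased"))
assert context.resolve_transformer().model_name == "distilbert-base-uncased"  # user preset wins
```

The design choice here is that inference-time loading should never fail just because the search space lacked an embedding node; a sensible default is always available.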

autointent/modules/base/_base.py

Lines changed: 10 additions & 3 deletions

```diff
@@ -138,9 +138,16 @@ def from_context(cls, context: Context, **kwargs: dict[str, Any]) -> "BaseModule
             Initialized module
         """

-    def get_embedder_config(self) -> dict[str, Any] | None:
-        """Get the config of the embedder."""
-        return None
+    @abstractmethod
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        """Return default params used in the ``__init__`` method.
+
+        Some parameters of a module may be inferred from the context rather than passed to ``__init__``,
+        but they need to be logged for reproducibility when loading from disk.
+
+        Returns:
+            Dictionary of default params
+        """

     @staticmethod
     def score_metrics_ho(params: tuple[Any, Any], metrics_dict: dict[str, Any]) -> dict[str, float]:
```
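To see why this abstract method exists, consider a module whose config is inferred via `from_context` rather than passed explicitly: without recording the inferred value, a dump could not be re-instantiated identically at inference time. A hypothetical stand-in below (not the real `KNNScorer` signature) shows the intended contract:

```python
from typing import Any


class KNNScorerStandIn:
    def __init__(self, embedder_config: dict[str, Any] | None = None) -> None:
        # When constructed via from_context, embedder_config is inferred from
        # the optimization context instead of being passed by the user.
        self.embedder_config = embedder_config or {"model_name": "all-MiniLM-L6-v2"}

    def get_implicit_initialization_params(self) -> dict[str, Any]:
        # Everything inferred at construction time is surfaced here so that
        # a dump can be re-instantiated identically later.
        return {"embedder_config": self.embedder_config}


module = KNNScorerStandIn()
params = module.get_implicit_initialization_params()
restored = KNNScorerStandIn(**params)  # reproduces the inferred configuration
assert restored.embedder_config == module.embedder_config
```

This is the pattern the child classes below implement: base classes with no implicit params return `{}`, while scorers return the configs they resolved from the context.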

autointent/modules/base/_decision.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -18,6 +18,9 @@
 class BaseDecision(BaseModule, ABC):
     """Base class for decision modules."""

+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {}
+
     @abstractmethod
     def fit(
         self,
```

autointent/modules/base/_embedding.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -1,6 +1,7 @@
 """Base class for embedding modules."""

 from abc import ABC
+from typing import Any

 from autointent import Context
 from autointent.custom_types import ListOfLabels
@@ -10,6 +11,9 @@
 class BaseEmbedding(BaseModule, ABC):
     """Base class for embedding modules."""

+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {}
+
     def get_train_data(self, context: Context) -> tuple[list[str], ListOfLabels]:
         """Get train data.
```

autointent/modules/base/_regex.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -1,9 +1,13 @@
 """Base class for embedding modules."""

 from abc import ABC
+from typing import Any

 from autointent.modules.base import BaseModule


 class BaseRegex(BaseModule, ABC):
     """Base class for rule-based modules."""
+
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {}
```

autointent/modules/scoring/_bert.py

Lines changed: 9 additions & 9 deletions

```diff
@@ -26,8 +26,8 @@ class BertScorer(BaseScorer):
     name = "bert"
     supports_multiclass = True
     supports_multilabel = True
-    _model: Any
-    _tokenizer: Any
+    _model: Any  # transformers AutoModel factory returns Any
+    _tokenizer: Any  # transformers AutoTokenizer factory returns Any

     def __init__(
         self,
@@ -56,7 +56,7 @@ def from_context(
         seed: int = 0,
     ) -> "BertScorer":
         if classification_model_config is None:
-            classification_model_config = context.resolve_embedder()
+            classification_model_config = context.resolve_transformer()

         report_to = context.logging_config.report_to

@@ -69,14 +69,14 @@
             report_to=report_to,
         )

-    def get_embedder_config(self) -> dict[str, Any]:
-        return self.classification_model_config.model_dump()
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {"classification_model_config": self.classification_model_config.model_dump()}

-    def __initialize_model(self) -> None:
+    def _initialize_model(self) -> Any:  # noqa: ANN401
         label2id = {i: i for i in range(self._n_classes)}
         id2label = {i: i for i in range(self._n_classes)}

-        self._model = AutoModelForSequenceClassification.from_pretrained(
+        return AutoModelForSequenceClassification.from_pretrained(
             self.classification_model_config.model_name,
             trust_remote_code=self.classification_model_config.trust_remote_code,
             num_labels=self._n_classes,
@@ -96,7 +96,7 @@ def fit(

         self._tokenizer = AutoTokenizer.from_pretrained(self.classification_model_config.model_name)

-        self.__initialize_model()
+        self._model = self._initialize_model()

         use_cpu = self.classification_model_config.device == "cpu"

@@ -126,7 +126,7 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             save_strategy="no",
             logging_strategy="steps",
             logging_steps=10,
-            report_to=self.report_to,
+            report_to=self.report_to if self.report_to is not None else "none",
             use_cpu=use_cpu,
         )
```
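The `report_to` change at the end guards against a transformers quirk: in many releases, passing `report_to=None` to `TrainingArguments` is interpreted as "all installed integrations" (wandb, tensorboard, and so on), so `None` has to be translated to the string `"none"` to actually disable external reporting. An illustrative one-liner, assuming only that `transformers` is installed:

```python
from transformers import TrainingArguments

# The string "none" disables integration reporting; a bare None could enable
# every integration discovered in the environment in older transformers releases.
args = TrainingArguments(output_dir="out", report_to="none")
```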

autointent/modules/scoring/_description/description.py

Lines changed: 9 additions & 17 deletions

```diff
@@ -76,9 +76,9 @@ def from_context(
         Returns:
             Initialized DescriptionScorer instance
         """
-        if embedder_config is None:
+        if embedder_config is None and encoder_type == "bi":
             embedder_config = context.resolve_embedder()
-        if cross_encoder_config is None:
+        if cross_encoder_config is None and encoder_type == "cross":
             cross_encoder_config = context.resolve_ranker()

         return cls(
@@ -88,21 +88,13 @@
             encoder_type=encoder_type,
         )

-    def get_embedder_config(self) -> dict[str, Any]:
-        """Get the configuration of the embedder.
-
-        Returns:
-            Embedder configuration
-        """
-        return self.embedder_config.model_dump()
-
-    def get_cross_encoder_config(self) -> dict[str, Any]:
-        """Get the configuration of the cross-encoder.
-
-        Returns:
-            Cross-encoder configuration
-        """
-        return self.cross_encoder_config.model_dump()
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        res = {}
+        if self._encoder_type == "bi":
+            res["embedder_config"] = self.embedder_config.model_dump()
+        else:
+            res["cross_encoder_config"] = self.cross_encoder_config.model_dump()
+        return res

     def fit(
         self,
```

autointent/modules/scoring/_dnnc/dnnc.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -101,6 +101,12 @@ def from_context(
             cross_encoder_config=cross_encoder_config,
         )

+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {
+            "embedder_config": self.embedder_config.model_dump(),
+            "cross_encoder_config": self.cross_encoder_config.model_dump(),
+        }
+
     def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
         """Fit the scorer by training or loading the vector index.
```

autointent/modules/scoring/_knn/knn.py

Lines changed: 2 additions & 7 deletions

```diff
@@ -97,13 +97,8 @@ def from_context(
             weights=weights,
         )

-    def get_embedder_config(self) -> dict[str, Any]:
-        """Get the name of the embedder.
-
-        Returns:
-            Embedder name
-        """
-        return self.embedder_config.model_dump()
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {"embedder_config": self.embedder_config.model_dump()}

     def fit(self, utterances: list[str], labels: ListOfLabels, clear_cache: bool = False) -> None:
         """Fit the scorer by training or loading the vector index.
```
