
Commit bcdf0f2

implement lazy importing transformers
1 parent 104378f commit bcdf0f2

5 files changed: +67 -51 lines

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ dependencies = [
 [project.optional-dependencies]
 catboost = ["catboost (>=1.2.8,<2.0.0)"]
 peft = ["peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)"]
+transformers = ["transformers[torch] (>=4.49.0,<5.0.0)"]
 dspy = [
     "dspy (>=2.6.5,<3.0.0)",
 ]
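
Since transformers is now only an optional extra, users opt in with pip install "autointent[transformers]", and any code path that touches it has to either probe for the package or import it lazily. A minimal sketch of such a probe, assuming nothing beyond the standard library (the helper name is illustrative and not part of this commit):

import importlib.util


def transformers_available() -> bool:
    """Report whether the optional 'transformers' extra is installed."""
    return importlib.util.find_spec("transformers") is not None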

src/autointent/_dump_tools/unit_dumpers.py

Lines changed: 25 additions & 20 deletions
@@ -10,13 +10,6 @@
 import numpy.typing as npt
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
-from transformers import (
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-    PreTrainedModel,
-    PreTrainedTokenizer,
-    PreTrainedTokenizerFast,
-)
 
 from autointent import Embedder, Ranker, VectorIndex
 from autointent._utils import require
@@ -28,6 +21,7 @@
 if TYPE_CHECKING:
     from catboost import CatBoostClassifier
     from peft import PeftModel
+    from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
 
 T = TypeVar("T")
 logger = logging.getLogger(__name__)
@@ -223,21 +217,22 @@ def dump(obj: "PeftModel", path: Path, exists_ok: bool) -> None:
         # strategy to save lora models: merge adapters and save as usual hugging face model
         lora_path = path / "lora"
         lora_path.mkdir(parents=True, exist_ok=exists_ok)
-        merged_model: PreTrainedModel = obj.merge_and_unload()
+        merged_model: "PreTrainedModel" = obj.merge_and_unload()
         merged_model.save_pretrained(lora_path)
 
     @staticmethod
     def load(path: Path, **kwargs: Any) -> "PeftModel":  # noqa: ANN401, ARG004
         peft = require("peft", extra="peft")
+        transformers = require("transformers", extra="transformers")
         if (path / "ptuning").exists():
             # prompt learning model
             ptuning_path = path / "ptuning"
-            model = AutoModelForSequenceClassification.from_pretrained(ptuning_path / "base_model")
+            model = transformers.AutoModelForSequenceClassification.from_pretrained(ptuning_path / "base_model")
             return peft.PeftModel.from_pretrained(model, ptuning_path / "peft")
         if (path / "lora").exists():
             # merged lora model
             lora_path = path / "lora"
-            return AutoModelForSequenceClassification.from_pretrained(lora_path)  # type: ignore[no-any-return]
+            return transformers.AutoModelForSequenceClassification.from_pretrained(lora_path)  # type: ignore[no-any-return]
         msg = f"Invalid PeftModel directory structure at {path}. Expected 'ptuning' or 'lora' subdirectory."
         raise ValueError(msg)
 
@@ -250,38 +245,48 @@ def check_isinstance(cls, obj: Any) -> bool:  # noqa: ANN401
         return False
 
 
-class HFModelDumper(BaseObjectDumper[PreTrainedModel]):
+class HFModelDumper(BaseObjectDumper["PreTrainedModel"]):
     dir_or_file_name = "hf_models"
 
     @staticmethod
-    def dump(obj: PreTrainedModel, path: Path, exists_ok: bool) -> None:
+    def dump(obj: "PreTrainedModel", path: Path, exists_ok: bool) -> None:
         path.mkdir(parents=True, exist_ok=exists_ok)
         obj.save_pretrained(path)
 
     @staticmethod
-    def load(path: Path, **kwargs: Any) -> PreTrainedModel:  # noqa: ANN401, ARG004
-        return AutoModelForSequenceClassification.from_pretrained(path)  # type: ignore[no-any-return]
+    def load(path: Path, **kwargs: Any) -> "PreTrainedModel":  # noqa: ANN401, ARG004
+        transformers = require("transformers", extra="transformers")
+        return transformers.AutoModelForSequenceClassification.from_pretrained(path)  # type: ignore[no-any-return]
 
     @classmethod
     def check_isinstance(cls, obj: Any) -> bool:  # noqa: ANN401
-        return isinstance(obj, PreTrainedModel)
+        try:
+            transformers = require("transformers", extra="transformers")
+            return isinstance(obj, transformers.PreTrainedModel)
+        except ImportError:
+            return False
 
 
-class HFTokenizerDumper(BaseObjectDumper[PreTrainedTokenizer | PreTrainedTokenizerFast]):
+class HFTokenizerDumper(BaseObjectDumper["PreTrainedTokenizer | PreTrainedTokenizerFast"]):
     dir_or_file_name = "hf_tokenizers"
 
     @staticmethod
-    def dump(obj: PreTrainedTokenizer | PreTrainedTokenizerFast, path: Path, exists_ok: bool) -> None:
+    def dump(obj: "PreTrainedTokenizer | PreTrainedTokenizerFast", path: Path, exists_ok: bool) -> None:
         path.mkdir(parents=True, exist_ok=exists_ok)
         obj.save_pretrained(path)
 
     @staticmethod
-    def load(path: Path, **kwargs: Any) -> PreTrainedTokenizer | PreTrainedTokenizerFast:  # noqa: ANN401, ARG004
-        return AutoTokenizer.from_pretrained(path)  # type: ignore[no-any-return,no-untyped-call]
+    def load(path: Path, **kwargs: Any) -> "PreTrainedTokenizer | PreTrainedTokenizerFast":  # noqa: ANN401, ARG004
+        transformers = require("transformers", extra="transformers")
+        return transformers.AutoTokenizer.from_pretrained(path)  # type: ignore[no-any-return,no-untyped-call]
 
     @classmethod
     def check_isinstance(cls, obj: Any) -> bool:  # noqa: ANN401
-        return isinstance(obj, PreTrainedTokenizer | PreTrainedTokenizerFast)
+        try:
+            transformers = require("transformers", extra="transformers")
+            return isinstance(obj, transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast)
+        except ImportError:
+            return False
 
 
 class TorchModelDumper(BaseObjectDumper[BaseTorchModule]):
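
The changes above route every transformers access through require from autointent._utils, which this commit does not show. A minimal sketch of what such a helper typically looks like, assuming it wraps importlib and points at the matching optional extra (the exact signature and error message are assumptions):

import importlib
from types import ModuleType


def require(module_name: str, *, extra: str) -> ModuleType:
    """Import an optional dependency on demand, failing with an install hint."""
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        msg = (
            f"Optional dependency '{module_name}' is not installed. "
            f"Install it with: pip install 'autointent[{extra}]'"
        )
        raise ImportError(msg) from err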

src/autointent/_wrappers/embedder/sentence_transformers.py

Lines changed: 10 additions & 4 deletions
@@ -2,7 +2,7 @@
 import tempfile
 from functools import lru_cache
 from pathlib import Path
-from typing import Literal, cast, overload
+from typing import TYPE_CHECKING, Literal, cast, overload
 from uuid import uuid4
 
 import huggingface_hub
@@ -14,16 +14,19 @@
 from sentence_transformers.losses import BatchAllTripletLoss
 from sentence_transformers.training_args import BatchSamplers
 from sklearn.model_selection import train_test_split
-from transformers import EarlyStoppingCallback, TrainerCallback
 
 from autointent._hash import Hasher
+from autointent._utils import require
 from autointent.configs import EmbedderFineTuningConfig, TaskTypeEnum
 from autointent.configs._embedder import SentenceTransformerEmbeddingConfig
 from autointent.custom_types import ListOfLabels
 
 from .base import BaseEmbeddingBackend
 from .utils import get_embeddings_path
 
+if TYPE_CHECKING:
+    from transformers import TrainerCallback
+
 logger = logging.getLogger(__name__)
 
 
@@ -234,6 +237,9 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin
 
         loss = BatchAllTripletLoss(model=model, margin=config.margin)
         with tempfile.TemporaryDirectory() as tmp_dir:
+            # Lazy import transformers (only needed for fine-tuning)
+            transformers = require("transformers", extra="transformers")
+
             args = SentenceTransformerTrainingArguments(
                 save_strategy="epoch",
                 save_total_limit=1,
@@ -251,8 +257,8 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin
                 eval_strategy="epoch",
                 greater_is_better=False,
             )
-            callbacks: list[TrainerCallback] = [
-                EarlyStoppingCallback(
+            callbacks: list["TrainerCallback"] = [
+                transformers.EarlyStoppingCallback(
                     early_stopping_patience=config.early_stopping_patience,
                     early_stopping_threshold=config.early_stopping_threshold,
                 )
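
Moving the TrainerCallback import under TYPE_CHECKING and quoting the annotation keeps static typing intact without importing transformers at module load. A self-contained illustration of that pattern (the function below is made up for demonstration, not taken from this file):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; never executed at runtime.
    from transformers import TrainerCallback


def callback_name(callback: "TrainerCallback") -> str:
    # The quoted annotation is never evaluated at runtime, so this module
    # imports cleanly even when transformers is not installed.
    return type(callback).__name__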

src/autointent/context/data_handler/_stratification.py

Lines changed: 3 additions & 2 deletions
@@ -5,6 +5,7 @@
 """
 
 import logging
+import random
 from collections.abc import Sequence
 
 import numpy as np
@@ -13,7 +14,6 @@
 from numpy import typing as npt
 from sklearn.model_selection import train_test_split
 from skmultilearn.model_selection import IterativeStratification
-from transformers import set_seed
 
 from autointent import Dataset
 from autointent.custom_types import LabelType
@@ -156,7 +156,8 @@ def _split_multilabel(self, dataset: HFDataset, test_size: float) -> Sequence[np
             A sequence containing indices for train and test splits.
         """
         if self.random_seed is not None:
-            set_seed(self.random_seed)  # workaround for buggy nature of IterativeStratification from skmultilearn
+            # Set all seeds for reproducibility (workaround for buggy nature of IterativeStratification from skmultilearn)
+            random.seed(self.random_seed)
         splitter = IterativeStratification(
             n_splits=2,
             order=2,
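
transformers.set_seed seeded Python's random, NumPy, and torch in one call, whereas the replacement seeds only the standard-library RNG. If NumPy-level determinism also matters for the stratified split, a lightweight stand-in could seed both; whether IterativeStratification consults NumPy's RNG is an assumption here, not something this diff establishes:

import random

import numpy as np


def seed_for_stratification(seed: int) -> None:
    """Seed the stdlib and NumPy RNGs without importing transformers."""
    random.seed(seed)
    np.random.seed(seed)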

src/autointent/modules/scoring/_bert.py

Lines changed: 28 additions & 25 deletions
@@ -2,33 +2,25 @@
 
 import tempfile
 from collections.abc import Callable
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal
 
 import numpy as np
 import numpy.typing as npt
 import torch
 from datasets import Dataset, DatasetDict
 from sklearn.model_selection import train_test_split
-from transformers import (
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-    DataCollatorWithPadding,
-    EarlyStoppingCallback,
-    EvalPrediction,
-    PrinterCallback,
-    ProgressCallback,
-    Trainer,
-    TrainingArguments,
-)
-from transformers.trainer_callback import TrainerCallback
 
 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
+from autointent._utils import require
 from autointent.configs import EarlyStoppingConfig, HFModelConfig
 from autointent.custom_types import ListOfLabels
 from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
 from autointent.modules.base import BaseScorer
 
+if TYPE_CHECKING:
+    from transformers import EvalPrediction, TrainerCallback
+
 
 class BertScorer(BaseScorer):
     """Scoring module for transformer-based classification using BERT models.
@@ -90,6 +82,17 @@ def __init__(
         early_stopping_config: EarlyStoppingConfig | dict[str, Any] | None = None,
         print_progress: bool = False,
     ) -> None:
+        # Lazy import transformers
+        transformers = require("transformers", extra="transformers")
+        self._AutoModelForSequenceClassification = transformers.AutoModelForSequenceClassification
+        self._AutoTokenizer = transformers.AutoTokenizer
+        self._DataCollatorWithPadding = transformers.DataCollatorWithPadding
+        self._EarlyStoppingCallback = transformers.EarlyStoppingCallback
+        self._PrinterCallback = transformers.PrinterCallback
+        self._ProgressCallback = transformers.ProgressCallback
+        self._Trainer = transformers.Trainer
+        self._TrainingArguments = transformers.TrainingArguments
+
         self.classification_model_config = HFModelConfig.from_search_config(classification_model_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
@@ -132,7 +135,7 @@ def _initialize_model(self) -> Any:  # noqa: ANN401
         label2id = {i: i for i in range(self._n_classes)}
         id2label = {i: i for i in range(self._n_classes)}
 
-        return AutoModelForSequenceClassification.from_pretrained(
+        return self._AutoModelForSequenceClassification.from_pretrained(
            self.classification_model_config.model_name,
            trust_remote_code=self.classification_model_config.trust_remote_code,
            num_labels=self._n_classes,
@@ -148,7 +151,7 @@ def fit(
     ) -> None:
         self._validate_task(labels)
 
-        self._tokenizer = AutoTokenizer.from_pretrained(self.classification_model_config.model_name)  # type: ignore[no-untyped-call]
+        self._tokenizer = self._AutoTokenizer.from_pretrained(self.classification_model_config.model_name)  # type: ignore[no-untyped-call]
         self._model = self._initialize_model()
         tokenized_dataset = self._get_tokenized_dataset(utterances, labels)
         self._train(tokenized_dataset)
@@ -162,7 +165,7 @@ def _train(self, tokenized_dataset: DatasetDict) -> None:
             tokenized_dataset: output from :py:meth:`BertScorer._get_tokenized_dataset`
         """
         with tempfile.TemporaryDirectory() as tmp_dir:
-            training_args = TrainingArguments(
+            training_args = self._TrainingArguments(
                 output_dir=tmp_dir,
                 num_train_epochs=self.num_train_epochs,
                 per_device_train_batch_size=self.batch_size,
@@ -181,27 +184,27 @@ def _train(self, tokenized_dataset: DatasetDict) -> None:
                 load_best_model_at_end=self.early_stopping_config.metric is not None,
             )
 
-            trainer = Trainer(
+            trainer = self._Trainer(
                 model=self._model,
                 args=training_args,
                 train_dataset=tokenized_dataset["train"],
                 eval_dataset=tokenized_dataset["validation"],
                 processing_class=self._tokenizer,
-                data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer),
+                data_collator=self._DataCollatorWithPadding(tokenizer=self._tokenizer),
                 compute_metrics=self._get_compute_metrics(),
                 callbacks=self._get_trainer_callbacks(),
             )
             if not self.print_progress:
-                trainer.remove_callback(PrinterCallback)
-                trainer.remove_callback(ProgressCallback)
+                trainer.remove_callback(self._PrinterCallback)
+                trainer.remove_callback(self._ProgressCallback)
 
             trainer.train()
 
-    def _get_trainer_callbacks(self) -> list[TrainerCallback]:
-        res: list[TrainerCallback] = []
+    def _get_trainer_callbacks(self) -> list["TrainerCallback"]:
+        res: list["TrainerCallback"] = []
         if self.early_stopping_config.metric is not None:
             res.append(
-                EarlyStoppingCallback(
+                self._EarlyStoppingCallback(
                     early_stopping_patience=self.early_stopping_config.patience,
                     early_stopping_threshold=self.early_stopping_config.threshold,
                 )
@@ -235,7 +238,7 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
 
         return dataset.map(tokenize_function, batched=True, batch_size=self.batch_size)
 
-    def _get_compute_metrics(self) -> Callable[[EvalPrediction], dict[str, float]] | None:
+    def _get_compute_metrics(self) -> Callable[["EvalPrediction"], dict[str, float]] | None:
         """Construct callable for computing metrics during transformer training.
 
         The result of this function is supposed to pass to :py:class:`transformers.Trainer`.
@@ -246,7 +249,7 @@ def _get_compute_metrics(self) -> Callable[[EvalPrediction], dict[str, float]] |
         metric_name = self.early_stopping_config.metric
         metric_fn = (SCORING_METRICS_MULTILABEL | SCORING_METRICS_MULTICLASS)[metric_name]
 
-        def compute_metrics(output: EvalPrediction) -> dict[str, float]:
+        def compute_metrics(output: "EvalPrediction") -> dict[str, float]:
             return {
                 metric_name: metric_fn(output.label_ids.tolist(), output.predictions.tolist())  # type: ignore[union-attr]
             }
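
BertScorer now resolves transformers once in __init__ and keeps the classes it needs as private attributes, so a missing extra surfaces as a clear error at construction time instead of at module import. A stripped-down sketch of that shape, using the same require helper shown in the diff above (the class itself is illustrative, not autointent's actual API):

from typing import Any

from autointent._utils import require


class LazyTrainerFactory:
    """Bind transformers classes at construction time rather than import time."""

    def __init__(self) -> None:
        # Raises with an install hint here if the 'transformers' extra is absent.
        transformers = require("transformers", extra="transformers")
        self._Trainer = transformers.Trainer
        self._TrainingArguments = transformers.TrainingArguments

    def training_args(self, output_dir: str, **kwargs: Any) -> Any:
        return self._TrainingArguments(output_dir=output_dir, **kwargs)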
