22
33import tempfile
44from collections .abc import Callable
5- from typing import Any , Literal
5+ from typing import TYPE_CHECKING , Any , Literal
66
77import numpy as np
88import numpy .typing as npt
99import torch
1010from datasets import Dataset , DatasetDict
1111from sklearn .model_selection import train_test_split
12- from transformers import (
13- AutoModelForSequenceClassification ,
14- AutoTokenizer ,
15- DataCollatorWithPadding ,
16- EarlyStoppingCallback ,
17- EvalPrediction ,
18- PrinterCallback ,
19- ProgressCallback ,
20- Trainer ,
21- TrainingArguments ,
22- )
23- from transformers .trainer_callback import TrainerCallback
2412
2513from autointent import Context
2614from autointent ._callbacks import REPORTERS_NAMES
15+ from autointent ._utils import require
2716from autointent .configs import EarlyStoppingConfig , HFModelConfig
2817from autointent .custom_types import ListOfLabels
2918from autointent .metrics import SCORING_METRICS_MULTICLASS , SCORING_METRICS_MULTILABEL
3019from autointent .modules .base import BaseScorer
3120
21+ if TYPE_CHECKING :
22+ from transformers import EvalPrediction , TrainerCallback
23+
3224
3325class BertScorer (BaseScorer ):
3426 """Scoring module for transformer-based classification using BERT models.
@@ -90,6 +82,17 @@ def __init__(
9082 early_stopping_config : EarlyStoppingConfig | dict [str , Any ] | None = None ,
9183 print_progress : bool = False ,
9284 ) -> None :
85+ # Lazy import transformers
86+ transformers = require ("transformers" , extra = "transformers" )
87+ self ._AutoModelForSequenceClassification = transformers .AutoModelForSequenceClassification
88+ self ._AutoTokenizer = transformers .AutoTokenizer
89+ self ._DataCollatorWithPadding = transformers .DataCollatorWithPadding
90+ self ._EarlyStoppingCallback = transformers .EarlyStoppingCallback
91+ self ._PrinterCallback = transformers .PrinterCallback
92+ self ._ProgressCallback = transformers .ProgressCallback
93+ self ._Trainer = transformers .Trainer
94+ self ._TrainingArguments = transformers .TrainingArguments
95+
9396 self .classification_model_config = HFModelConfig .from_search_config (classification_model_config )
9497 self .num_train_epochs = num_train_epochs
9598 self .batch_size = batch_size
@@ -132,7 +135,7 @@ def _initialize_model(self) -> Any: # noqa: ANN401
132135 label2id = {i : i for i in range (self ._n_classes )}
133136 id2label = {i : i for i in range (self ._n_classes )}
134137
135- return AutoModelForSequenceClassification .from_pretrained (
138+ return self . _AutoModelForSequenceClassification .from_pretrained (
136139 self .classification_model_config .model_name ,
137140 trust_remote_code = self .classification_model_config .trust_remote_code ,
138141 num_labels = self ._n_classes ,
@@ -148,7 +151,7 @@ def fit(
148151 ) -> None :
149152 self ._validate_task (labels )
150153
151- self ._tokenizer = AutoTokenizer .from_pretrained (self .classification_model_config .model_name ) # type: ignore[no-untyped-call]
154+ self ._tokenizer = self . _AutoTokenizer .from_pretrained (self .classification_model_config .model_name ) # type: ignore[no-untyped-call]
152155 self ._model = self ._initialize_model ()
153156 tokenized_dataset = self ._get_tokenized_dataset (utterances , labels )
154157 self ._train (tokenized_dataset )
@@ -162,7 +165,7 @@ def _train(self, tokenized_dataset: DatasetDict) -> None:
162165 tokenized_dataset: output from :py:meth:`BertScorer._get_tokenized_dataset`
163166 """
164167 with tempfile .TemporaryDirectory () as tmp_dir :
165- training_args = TrainingArguments (
168+ training_args = self . _TrainingArguments (
166169 output_dir = tmp_dir ,
167170 num_train_epochs = self .num_train_epochs ,
168171 per_device_train_batch_size = self .batch_size ,
@@ -181,27 +184,27 @@ def _train(self, tokenized_dataset: DatasetDict) -> None:
181184 load_best_model_at_end = self .early_stopping_config .metric is not None ,
182185 )
183186
184- trainer = Trainer (
187+ trainer = self . _Trainer (
185188 model = self ._model ,
186189 args = training_args ,
187190 train_dataset = tokenized_dataset ["train" ],
188191 eval_dataset = tokenized_dataset ["validation" ],
189192 processing_class = self ._tokenizer ,
190- data_collator = DataCollatorWithPadding (tokenizer = self ._tokenizer ),
193+ data_collator = self . _DataCollatorWithPadding (tokenizer = self ._tokenizer ),
191194 compute_metrics = self ._get_compute_metrics (),
192195 callbacks = self ._get_trainer_callbacks (),
193196 )
194197 if not self .print_progress :
195- trainer .remove_callback (PrinterCallback )
196- trainer .remove_callback (ProgressCallback )
198+ trainer .remove_callback (self . _PrinterCallback )
199+ trainer .remove_callback (self . _ProgressCallback )
197200
198201 trainer .train ()
199202
200- def _get_trainer_callbacks (self ) -> list [TrainerCallback ]:
201- res : list [TrainerCallback ] = []
203+ def _get_trainer_callbacks (self ) -> list [" TrainerCallback" ]:
204+ res : list [" TrainerCallback" ] = []
202205 if self .early_stopping_config .metric is not None :
203206 res .append (
204- EarlyStoppingCallback (
207+ self . _EarlyStoppingCallback (
205208 early_stopping_patience = self .early_stopping_config .patience ,
206209 early_stopping_threshold = self .early_stopping_config .threshold ,
207210 )
@@ -235,7 +238,7 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
235238
236239 return dataset .map (tokenize_function , batched = True , batch_size = self .batch_size )
237240
238- def _get_compute_metrics (self ) -> Callable [[EvalPrediction ], dict [str , float ]] | None :
241+ def _get_compute_metrics (self ) -> Callable [[" EvalPrediction" ], dict [str , float ]] | None :
239242 """Construct callable for computing metrics during transformer training.
240243
241244 The result of this function is supposed to pass to :py:class:`transformers.Trainer`.
@@ -246,7 +249,7 @@ def _get_compute_metrics(self) -> Callable[[EvalPrediction], dict[str, float]] |
246249 metric_name = self .early_stopping_config .metric
247250 metric_fn = (SCORING_METRICS_MULTILABEL | SCORING_METRICS_MULTICLASS )[metric_name ]
248251
249- def compute_metrics (output : EvalPrediction ) -> dict [str , float ]:
252+ def compute_metrics (output : " EvalPrediction" ) -> dict [str , float ]:
250253 return {
251254 metric_name : metric_fn (output .label_ids .tolist (), output .predictions .tolist ()) # type: ignore[union-attr]
252255 }
0 commit comments