Skip to content

Commit 5196a3e

Browse files
committed
try to suppress
1 parent 328ca2b commit 5196a3e

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

autointent/modules/scoring/_bert.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,20 @@ def __init__(
9696
self.seed = seed
9797
self.report_to = report_to
9898
self.early_stopping_config = early_stopping_config or EarlyStoppingConfig()
99-
self.training_arguments = training_arguments
99+
self.training_arguments = training_arguments or {}
100+
# init here for faster validation
101+
self.training_args = TrainingArguments(
102+
num_train_epochs=self.num_train_epochs,
103+
per_device_train_batch_size=self.batch_size,
104+
learning_rate=self.learning_rate,
105+
seed=self.seed,
106+
report_to=self.report_to if self.report_to is not None else "none",
107+
use_cpu=self.classification_model_config.device == "cpu",
108+
metric_for_best_model=self.early_stopping_config.metric,
109+
load_best_model_at_end=self.early_stopping_config.metric is not None,
110+
**self.training_arguments,
111+
)
112+
100113

101114
@classmethod
102115
def from_context(
@@ -163,22 +176,11 @@ def _train(self, tokenized_dataset: DatasetDict) -> None:
163176
tokenized_dataset: output from :py:meth:`BertScorer._get_tokenized_dataset`
164177
"""
165178
with tempfile.TemporaryDirectory() as tmp_dir:
166-
training_args = TrainingArguments(
167-
output_dir=tmp_dir,
168-
num_train_epochs=self.num_train_epochs,
169-
per_device_train_batch_size=self.batch_size,
170-
learning_rate=self.learning_rate,
171-
seed=self.seed,
172-
report_to=self.report_to if self.report_to is not None else "none",
173-
use_cpu=self.classification_model_config.device == "cpu",
174-
metric_for_best_model=self.early_stopping_config.metric,
175-
load_best_model_at_end=self.early_stopping_config.metric is not None,
176-
**self.training_arguments,
177-
)
179+
self.training_args.output_dir = tmp_dir
178180

179181
trainer = Trainer( # type: ignore[no-untyped-call]
180182
model=self._model,
181-
args=training_args,
183+
args=self.training_args,
182184
train_dataset=tokenized_dataset["train"],
183185
eval_dataset=tokenized_dataset["validation"],
184186
processing_class=self._tokenizer,
@@ -187,7 +189,7 @@ def _train(self, tokenized_dataset: DatasetDict) -> None:
187189
callbacks=self._get_trainer_callbacks(),
188190
)
189191

190-
trainer.train() # type: ignore[attr-defined]
192+
_ = trainer.train() # type: ignore[attr-defined]
191193

192194
def _get_trainer_callbacks(self) -> list[TrainerCallback]:
193195
res: list[TrainerCallback] = []

0 commit comments

Comments
 (0)