@@ -96,7 +96,20 @@ def __init__(
         self.seed = seed
         self.report_to = report_to
         self.early_stopping_config = early_stopping_config or EarlyStoppingConfig()
-        self.training_arguments = training_arguments
+        self.training_arguments = training_arguments or {}
+        # init here so invalid training arguments fail fast
+        self.training_args = TrainingArguments(
+            num_train_epochs=self.num_train_epochs,
+            per_device_train_batch_size=self.batch_size,
+            learning_rate=self.learning_rate,
+            seed=self.seed,
+            report_to=self.report_to if self.report_to is not None else "none",
+            use_cpu=self.classification_model_config.device == "cpu",
+            metric_for_best_model=self.early_stopping_config.metric,
+            load_best_model_at_end=self.early_stopping_config.metric is not None,
+            **self.training_arguments,
+        )
+

     @classmethod
     def from_context(
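With `TrainingArguments` now built eagerly in `__init__`, a misconfigured `training_arguments` override raises during construction instead of deep inside `_train`. A minimal sketch of that fail-fast behavior, assuming a recent `transformers` release where `output_dir` is optional:

```python
from transformers import TrainingArguments

# load_best_model_at_end requires matching eval/save strategies, so an
# inconsistent override is rejected at construction time, before any
# tokenization or training work has been done.
try:
    TrainingArguments(
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        eval_strategy="epoch",
        save_strategy="steps",
    )
except ValueError as err:
    print(f"rejected at init: {err}")
```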
@@ -163,22 +176,11 @@ def _train(self, tokenized_dataset: DatasetDict) -> None:
             tokenized_dataset: output from :py:meth:`BertScorer._get_tokenized_dataset`
         """
         with tempfile.TemporaryDirectory() as tmp_dir:
-            training_args = TrainingArguments(
-                output_dir=tmp_dir,
-                num_train_epochs=self.num_train_epochs,
-                per_device_train_batch_size=self.batch_size,
-                learning_rate=self.learning_rate,
-                seed=self.seed,
-                report_to=self.report_to if self.report_to is not None else "none",
-                use_cpu=self.classification_model_config.device == "cpu",
-                metric_for_best_model=self.early_stopping_config.metric,
-                load_best_model_at_end=self.early_stopping_config.metric is not None,
-                **self.training_arguments,
-            )
+            self.training_args.output_dir = tmp_dir

             trainer = Trainer(  # type: ignore[no-untyped-call]
                 model=self._model,
-                args=training_args,
+                args=self.training_args,
                 train_dataset=tokenized_dataset["train"],
                 eval_dataset=tokenized_dataset["validation"],
                 processing_class=self._tokenizer,
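Because `TrainingArguments` is a dataclass, `_train` can reuse the single instance built in `__init__` and simply point `output_dir` at a fresh temporary directory per run. A minimal sketch of the pattern, under the same optional-`output_dir` assumption as above:

```python
import tempfile

from transformers import TrainingArguments

args = TrainingArguments(report_to="none")

# Each run writes its checkpoints into a throwaway directory.
with tempfile.TemporaryDirectory() as tmp_dir:
    args.output_dir = tmp_dir
    print(args.output_dir)
```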
@@ -187,7 +189,7 @@ def _train(self, tokenized_dataset: DatasetDict) -> None:
                 callbacks=self._get_trainer_callbacks(),
             )

-            trainer.train()  # type: ignore[attr-defined]
+            _ = trainer.train()  # type: ignore[attr-defined]

     def _get_trainer_callbacks(self) -> list[TrainerCallback]:
         res: list[TrainerCallback] = []
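The diff cuts off before the body of `_get_trainer_callbacks`. Given the `metric_for_best_model`/`load_best_model_at_end` wiring above, a plausible (unconfirmed) sketch is that it attaches `transformers.EarlyStoppingCallback` when a metric is configured; the `patience` attribute on `EarlyStoppingConfig` is an assumption, not confirmed by this diff:

```python
from transformers import EarlyStoppingCallback, TrainerCallback

def _get_trainer_callbacks(self) -> list[TrainerCallback]:
    res: list[TrainerCallback] = []
    # Assumed wiring: early stopping only makes sense when a metric is
    # tracked, mirroring load_best_model_at_end in __init__.
    if self.early_stopping_config.metric is not None:
        res.append(
            EarlyStoppingCallback(
                early_stopping_patience=self.early_stopping_config.patience,  # assumed field
            )
        )
    return res
```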