Skip to content

Commit c1237a5

Browse files
authored
Chronos-2: Add option to remove PrinterCallback (#410)
*Issue #, if available:* #410

*Description of changes:* Adds a `remove_printer_callback` option to `Chronos2Pipeline.fit` that, when True, removes all instances of `PrinterCallback` from the HuggingFace `Trainer`'s callbacks.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent 1da6965 commit c1237a5

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

src/chronos/chronos2/pipeline.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from transformers.utils.import_utils import is_peft_available
2020
from transformers.utils.peft_utils import find_adapter_config_file
2121

22-
2322
import chronos.chronos2
2423
from chronos.base import BaseChronosPipeline, ForecastType
2524
from chronos.chronos2 import Chronos2Model
@@ -114,6 +113,7 @@ def fit(
114113
min_past: int | None = None,
115114
finetuned_ckpt_name: str = "finetuned-ckpt",
116115
callbacks: list["TrainerCallback"] | None = None,
116+
remove_printer_callback: bool = False,
117117
**extra_trainer_kwargs,
118118
) -> "Chronos2Pipeline":
119119
"""
@@ -156,6 +156,8 @@ def fit(
156156
The name of the directory inside `output_dir` in which the final fine-tuned checkpoint will be saved, by default "finetuned-ckpt"
157157
callbacks
158158
A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer`
159+
remove_printer_callback
160+
If True, all instances of `PrinterCallback` are removed from callbacks
159161
**extra_trainer_kwargs
160162
Extra kwargs are directly forwarded to `TrainingArguments`
161163
@@ -165,6 +167,7 @@ def fit(
165167
"""
166168

167169
import torch.cuda
170+
from transformers.trainer_callback import PrinterCallback
168171
from transformers.training_args import TrainingArguments
169172

170173
if finetune_mode == "lora":
@@ -322,6 +325,10 @@ def fit(
322325
eval_dataset=eval_dataset,
323326
callbacks=callbacks,
324327
)
328+
329+
if remove_printer_callback:
330+
trainer.pop_callback(PrinterCallback)
331+
325332
trainer.train()
326333

327334
# update max_output_patches, if the model was fine-tuned with longer prediction_length

0 commit comments

Comments (0)