@@ -18,7 +18,7 @@
 from torchmetrics import RunningMean

 from litgpt.adapter_v2 import GPT, Block, Config, adapter_filter, mark_only_adapter_v2_as_trainable
-from litgpt.args import EvalArgs, TrainArgs
+from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
 from litgpt.prompts import save_prompt_style
@@ -64,6 +64,7 @@ def setup(
         max_seq_length=None,
     ),
     eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
+    log: LogArgs = LogArgs(),
     optimizer: Union[str, Dict] = "AdamW",
     logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
     seed: int = 1337,
@@ -97,7 +98,13 @@ def setup(
     config = Config.from_file(checkpoint_dir / "model_config.yaml")

     precision = precision or get_default_supported_precision(training=True)
-    logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval)
+    logger = choose_logger(
+        logger_name,
+        out_dir,
+        name=f"finetune-{config.name}",
+        log_interval=train.log_interval,
+        log_args=dataclasses.asdict(log),
+    )

     plugins = None
     if quantize is not None and quantize.startswith("bnb."):
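For context, a minimal sketch of the pattern the last hunk relies on: the new `log` dataclass is flattened with `dataclasses.asdict` into a plain keyword dict before being forwarded to the logger factory. The `LogArgs` fields below (`project`, `run`, `group`) are illustrative assumptions only; see `litgpt/args.py` for the actual definition.

    import dataclasses
    from dataclasses import dataclass
    from typing import Optional

    # Illustrative stand-in for litgpt.args.LogArgs; the real fields may differ.
    @dataclass
    class LogArgs:
        project: Optional[str] = None  # assumed: project/experiment name for the logger
        run: Optional[str] = None      # assumed: run name
        group: Optional[str] = None    # assumed: run group

    log = LogArgs(project="adapter-v2-demo", run="first-run")

    # As in the diff above: the dataclass becomes a plain dict that is passed
    # to choose_logger(...) as the log_args keyword argument.
    log_args = dataclasses.asdict(log)
    print(log_args)  # {'project': 'adapter-v2-demo', 'run': 'first-run', 'group': None}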