
Commit e6740f5

ysjprojects, pre-commit-ci[bot], and shijie.yu authored
adding logger args (#1973)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: shijie.yu <[email protected]>
1 parent 241bbd6 commit e6740f5

8 files changed: +70 −13 lines changed
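The commit threads a new LogArgs dataclass through each setup() entry point and forwards it to choose_logger. A minimal sketch of how a caller might pass it through the Python API shown in the diffs below; the model name and the project/run/group values are hypothetical placeholders, not part of the commit:

from litgpt.args import LogArgs
from litgpt.pretrain import setup

# Placeholder values; every LogArgs field defaults to None when omitted.
setup(
    model_name="pythia-160m",  # hypothetical model choice
    logger_name="wandb",
    log=LogArgs(project="my-project", run="pretrain-run-1", group="baseline"),
)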

extensions/thunder/pretrain.py

Lines changed: 9 additions & 2 deletions
@@ -5,6 +5,7 @@
 import pprint
 import sys
 import time
+from dataclasses import asdict
 from datetime import timedelta
 from functools import partial
 from pathlib import Path
@@ -20,7 +21,7 @@
 from typing_extensions import Literal
 
 from litgpt import Tokenizer
-from litgpt.args import EvalArgs, TrainArgs
+from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import DataModule, TinyLlama
 from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
 from litgpt.utils import (
@@ -70,6 +71,7 @@ def setup(
         tie_embeddings=False,
     ),
     eval: EvalArgs = EvalArgs(interval=1000, max_iters=100),
+    log: LogArgs = LogArgs(),
     optimizer: Union[str, Dict] = "AdamW",
     devices: Union[int, str] = "auto",
     num_nodes: int = 1,
@@ -121,7 +123,12 @@ def setup(
     tokenizer = Tokenizer(tokenizer_dir) if tokenizer_dir is not None else None
 
     logger = choose_logger(
-        logger_name, out_dir, name=f"pretrain-{config.name}", resume=bool(resume), log_interval=train.log_interval
+        logger_name,
+        out_dir,
+        name=f"pretrain-{config.name}",
+        resume=bool(resume),
+        log_interval=train.log_interval,
+        log_args=asdict(log),
     )
 
     if devices * num_nodes > 1:

litgpt/args.py

Lines changed: 12 additions & 0 deletions
@@ -90,3 +90,15 @@ class EvalArgs:
     evaluate_example: Union[str, int] = "first"
     """How to pick an example instruction to evaluate periodically during training.
     Can be "first", "random", or an integer index to pick a specific example."""
+
+
+@dataclass
+class LogArgs:
+    """Logging-related arguments"""
+
+    project: Optional[str] = None
+    """Project name"""
+    run: Optional[str] = None
+    """Run name"""
+    group: Optional[str] = None
+    """Group name"""

litgpt/finetune/adapter.py

Lines changed: 9 additions & 2 deletions
@@ -18,7 +18,7 @@
 from torchmetrics import RunningMean
 
 from litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable
-from litgpt.args import EvalArgs, TrainArgs
+from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
 from litgpt.prompts import save_prompt_style
@@ -62,6 +62,7 @@ def setup(
         max_seq_length=None,
     ),
     eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
+    log: LogArgs = LogArgs(),
     optimizer: Union[str, Dict] = "AdamW",
     logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
     seed: int = 1337,
@@ -95,7 +96,13 @@ def setup(
     config = Config.from_file(checkpoint_dir / "model_config.yaml")
 
     precision = precision or get_default_supported_precision(training=True)
-    logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval)
+    logger = choose_logger(
+        logger_name,
+        out_dir,
+        name=f"finetune-{config.name}",
+        log_interval=train.log_interval,
+        log_args=dataclasses.asdict(log),
+    )
 
     plugins = None
     if quantize is not None and quantize.startswith("bnb."):

litgpt/finetune/adapter_v2.py

Lines changed: 9 additions & 2 deletions
@@ -18,7 +18,7 @@
 from torchmetrics import RunningMean
 
 from litgpt.adapter_v2 import GPT, Block, Config, adapter_filter, mark_only_adapter_v2_as_trainable
-from litgpt.args import EvalArgs, TrainArgs
+from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
 from litgpt.prompts import save_prompt_style
@@ -64,6 +64,7 @@ def setup(
         max_seq_length=None,
     ),
     eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
+    log: LogArgs = LogArgs(),
     optimizer: Union[str, Dict] = "AdamW",
     logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
     seed: int = 1337,
@@ -97,7 +98,13 @@ def setup(
     config = Config.from_file(checkpoint_dir / "model_config.yaml")
 
     precision = precision or get_default_supported_precision(training=True)
-    logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval)
+    logger = choose_logger(
+        logger_name,
+        out_dir,
+        name=f"finetune-{config.name}",
+        log_interval=train.log_interval,
+        log_args=dataclasses.asdict(log),
+    )
 
     plugins = None
     if quantize is not None and quantize.startswith("bnb."):

litgpt/finetune/full.py

Lines changed: 8 additions & 2 deletions
@@ -13,7 +13,7 @@
 from torch.utils.data import ConcatDataset, DataLoader
 from torchmetrics import RunningMean
 
-from litgpt.args import EvalArgs, TrainArgs
+from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
 from litgpt.model import GPT, Block, Config
@@ -58,6 +58,7 @@ def setup(
         max_seq_length=None,
     ),
     eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
+    log: LogArgs = LogArgs(),
     optimizer: Union[str, Dict] = "AdamW",
     logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
     seed: int = 1337,
@@ -94,7 +95,12 @@ def setup(
 
     precision = precision or get_default_supported_precision(training=True)
     logger = choose_logger(
-        logger_name, out_dir, name=f"finetune-{config.name}", resume=bool(resume), log_interval=train.log_interval
+        logger_name,
+        out_dir,
+        name=f"finetune-{config.name}",
+        resume=bool(resume),
+        log_interval=train.log_interval,
+        log_args=dataclasses.asdict(log),
     )
 
     if devices * num_nodes > 1:

litgpt/finetune/lora.py

Lines changed: 9 additions & 2 deletions
@@ -17,7 +17,7 @@
 from torch.utils.data import ConcatDataset, DataLoader
 from torchmetrics import RunningMean
 
-from litgpt.args import EvalArgs, TrainArgs
+from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
 from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
@@ -71,6 +71,7 @@ def setup(
         epochs=5,
         max_seq_length=None,
     ),
+    log: LogArgs = LogArgs(),
     eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
     optimizer: Union[str, Dict] = "AdamW",
     logger_name: Literal["wandb", "tensorboard", "csv", "mlflow"] = "csv",
@@ -125,7 +126,13 @@ def setup(
     )
 
     precision = precision or get_default_supported_precision(training=True)
-    logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval)
+    logger = choose_logger(
+        logger_name,
+        out_dir,
+        name=f"finetune-{config.name}",
+        log_interval=train.log_interval,
+        log_args=dataclasses.asdict(log),
+    )
 
     plugins = None
     if quantize is not None and quantize.startswith("bnb."):

litgpt/pretrain.py

Lines changed: 9 additions & 2 deletions
@@ -3,6 +3,7 @@
 import math
 import pprint
 import time
+from dataclasses import asdict
 from datetime import timedelta
 from functools import partial
 from pathlib import Path
@@ -18,7 +19,7 @@
 from typing_extensions import Literal
 
 from litgpt import Tokenizer
-from litgpt.args import EvalArgs, TrainArgs
+from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.config import name_to_config
 from litgpt.data import DataModule, TinyLlama
 from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
@@ -62,6 +63,7 @@ def setup(
         tie_embeddings=False,
     ),
     eval: EvalArgs = EvalArgs(interval=1000, max_iters=100),
+    log: LogArgs = LogArgs(),
     optimizer: Union[str, Dict] = "AdamW",
     devices: Union[int, str] = "auto",
     num_nodes: int = 1,
@@ -127,7 +129,12 @@ def setup(
     tokenizer = Tokenizer(tokenizer_dir) if tokenizer_dir is not None else None
 
     logger = choose_logger(
-        logger_name, out_dir, name=f"pretrain-{config.name}", resume=bool(resume), log_interval=train.log_interval
+        logger_name,
+        out_dir,
+        name=f"pretrain-{config.name}",
+        resume=bool(resume),
+        log_interval=train.log_interval,
+        log_args=asdict(log),
     )
 
     if devices * num_nodes > 1:

litgpt/utils.py

Lines changed: 5 additions & 1 deletion
@@ -542,6 +542,7 @@ def choose_logger(
     out_dir: Path,
     name: str,
     log_interval: int = 1,
+    log_args: Optional[Dict] = None,
     resume: Optional[bool] = None,
     **kwargs: Any,
 ):
@@ -550,7 +551,10 @@
     if logger_name == "tensorboard":
         return TensorBoardLogger(root_dir=(out_dir / "logs"), name="tensorboard", **kwargs)
     if logger_name == "wandb":
-        return WandbLogger(project=name, resume=resume, **kwargs)
+        project = log_args.pop("project", name)
+        run = log_args.pop("run", os.environ.get("WANDB_RUN_NAME"))
+        group = log_args.pop("group", os.environ.get("WANDB_RUN_GROUP"))
+        return WandbLogger(project=project, name=run, group=group, resume=resume, **kwargs)
     if logger_name == "mlflow":
         return MLFlowLogger(experiment_name=name, **kwargs)
     raise ValueError(f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'.")
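In the wandb branch, project, run, and group are pulled out of log_args with dict.pop, so the WANDB_RUN_NAME / WANDB_RUN_GROUP environment variables and the derived name act as fallbacks only when the corresponding key is missing from the dict. A standalone sketch of that lookup pattern with a hypothetical helper name, mirroring the lines above:

import os

def resolve_wandb_fields(log_args: dict, name: str) -> tuple:
    # dict.pop(key, default) returns the stored value (even if it is None)
    # and falls back to the default only when the key is absent.
    project = log_args.pop("project", name)
    run = log_args.pop("run", os.environ.get("WANDB_RUN_NAME"))
    group = log_args.pop("group", os.environ.get("WANDB_RUN_GROUP"))
    return project, run, group

# Keys present but None: the stored None wins over the fallbacks.
print(resolve_wandb_fields({"project": None, "run": "exp-01", "group": None}, "pretrain-tiny-llama"))
# (None, 'exp-01', None)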
