Commit 96a0e82

feat: add custom logger
1 parent fb18a5e

2 files changed: +45 -13 lines changed


open_diloco/train_fsdp.py

Lines changed: 6 additions & 12 deletions
@@ -15,7 +15,6 @@
 
 from pydantic import model_validator
 import torch
-import wandb
 from pydantic_config import parse_argv, BaseConfig
 from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
@@ -35,7 +34,6 @@
     MixedPrecision,
 )
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed import broadcast_object_list
 from open_diloco.ckpt_utils import (
     CKPT_PREFIX,
     CkptConfig,
@@ -46,7 +44,7 @@
     save_checkpoint,
 )
 from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
-
+from open_diloco.utils import WandbLogger, DummyLogger
 
 from hivemind.dht.dht import DHT
 from hivemind.utils.networking import log_visible_maddrs
@@ -120,6 +118,7 @@ class Config(BaseConfig):
     precision: Literal["fp16-mixed", "bf16-mixed", "32-true"] = "fp16-mixed"
     # Checkpointing and logging
     project: str = "hivemind_debug"
+    metric_logger_type: Literal["wandb", "dummy"] = "wandb"
     log_activations_steps: int | None = None
     ckpt: CkptConfig = CkptConfig()
     # Hivemind
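
The new field is an ordinary pydantic Literal, so a misspelled logger name fails when the config is parsed rather than mid-run. A minimal standalone sketch of the same validation pattern (plain pydantic BaseModel here, not the repo's BaseConfig):

from typing import Literal

from pydantic import BaseModel, ValidationError

class LoggingConfig(BaseModel):
    project: str = "hivemind_debug"
    metric_logger_type: Literal["wandb", "dummy"] = "wandb"

LoggingConfig(metric_logger_type="dummy")  # accepted
try:
    LoggingConfig(metric_logger_type="tensorboard")  # rejected by the Literal
except ValidationError as err:
    print(err)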
@@ -192,14 +191,9 @@ def train(config: Config):
         sharding_strategy = ShardingStrategy.NO_SHARD
         log("Hivemind is used, ShardingStrategy.NO_SHARD is used")
 
-    run_id = None
     if rank == 0:
-        wandb.init(project=config.project, config=config.model_dump())
-        run_id = wandb.run.id
-
-    run_id_list = [run_id]
-    broadcast_object_list(run_id_list, src=0)
-    run_id = run_id_list[0]
+        logger_cls = WandbLogger if config.metric_logger_type == "wandb" else DummyLogger
+        metric_logger = logger_cls(project=config.project, config=config.model_dump())
 
     if config.hv is not None:
         log("hivemind diloco enabled")
@@ -459,7 +453,7 @@ def scheduler_fn(opt):
 
             current_time = time.time()
 
-            wandb.log(metrics)
+            metric_logger.log(metrics)
 
             if config.hv is None:
                 log(
@@ -512,7 +506,7 @@ def scheduler_fn(opt):
         if config.max_steps is not None and real_step >= config.max_steps:
             break
     log("Training completed.")
-    wandb.finish()
+    metric_logger.finish()
 
 
 if __name__ == "__main__":

open_diloco/utils.py

Lines changed: 39 additions & 1 deletion
@@ -1,11 +1,13 @@
 import hashlib
 from functools import partial
-from typing import Any, Generator
+import json
+from typing import Any, Generator, Protocol
 
 import torch
 from torch.utils.hooks import RemovableHandle
 from torch.distributed.fsdp import ShardingStrategy
 from torch.utils.data import IterableDataset
+import wandb
 
 
 _FSDP_WRAPPED_MODULE = ["_forward_module.", "_fsdp_wrapped_module."]
@@ -175,3 +177,39 @@ def __iter__(self) -> Generator[dict[str, Any], Any, None]:
         input_ids = torch.randint(3, self.vocab_size, (self.seq_len,)).tolist()
         attention_mask = [1] * self.seq_len
         yield {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+class Logger(Protocol):
+    def __init__(self, project, config): ...
+
+    def log(self, metrics: dict[str, Any]): ...
+
+    def finish(self): ...
+
+
+class WandbLogger:
+    def __init__(self, project, config):
+        wandb.init(project=project, config=config)
+
+    def log(self, metrics: dict[str, Any]):
+        wandb.log(metrics)
+
+    def finish(self):
+        wandb.finish()
+
+
+class DummyLogger:
+    def __init__(self, project, config):
+        self.project = project
+        self.config = config
+        open(project, "a").close()  # Create an empty file at the project path
+
+        self.data = []
+
+    def log(self, metrics: dict[str, Any]):
+        self.data.append(metrics)
+
+    def finish(self):
+        with open(self.project, "a") as f:
+            for d in self.data:
+                f.write(json.dumps(d) + "\n")
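
Since DummyLogger buffers metrics in memory and only serializes them as JSON lines in finish(), the new path can be smoke-tested without a wandb account. A hypothetical check (the metrics.jsonl path is arbitrary; note that DummyLogger reuses project as its output file path):

import json

from open_diloco.utils import DummyLogger

logger = DummyLogger(project="metrics.jsonl", config={"lr": 1e-3})
logger.log({"step": 1, "loss": 2.5})
logger.log({"step": 2, "loss": 2.1})
logger.finish()  # nothing is written to disk before this call

with open("metrics.jsonl") as f:
    rows = [json.loads(line) for line in f]
assert rows[1] == {"step": 2, "loss": 2.1}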
