@@ -15,7 +15,6 @@

 from pydantic import model_validator
 import torch
-import wandb
 from pydantic_config import parse_argv, BaseConfig
 from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
@@ -35,7 +34,6 @@
     MixedPrecision,
 )
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed import broadcast_object_list
 from open_diloco.ckpt_utils import (
     CKPT_PREFIX,
     CkptConfig,
@@ -46,7 +44,7 @@
     save_checkpoint,
 )
 from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
-
+from open_diloco.utils import WandbLogger, DummyLogger

 from hivemind.dht.dht import DHT
 from hivemind.utils.networking import log_visible_maddrs
@@ -120,6 +118,7 @@ class Config(BaseConfig):
     precision: Literal["fp16-mixed", "bf16-mixed", "32-true"] = "fp16-mixed"
     # Checkpointing and logging
     project: str = "hivemind_debug"
+    metric_logger_type: Literal["wandb", "dummy"] = "wandb"
     log_activations_steps: int | None = None
     ckpt: CkptConfig = CkptConfig()
     # Hivemind
@@ -192,14 +191,9 @@ def train(config: Config):
         sharding_strategy = ShardingStrategy.NO_SHARD
         log("Hivemind is used, ShardingStrategy.NO_SHARD is used")

-    run_id = None
     if rank == 0:
-        wandb.init(project=config.project, config=config.model_dump())
-        run_id = wandb.run.id
-
-    run_id_list = [run_id]
-    broadcast_object_list(run_id_list, src=0)
-    run_id = run_id_list[0]
+        logger_cls = WandbLogger if config.metric_logger_type == "wandb" else DummyLogger
+        metric_logger = logger_cls(project=config.project, config=config.model_dump())

     if config.hv is not None:
         log("hivemind diloco enabled")
@@ -459,7 +453,7 @@ def scheduler_fn(opt):

                 current_time = time.time()

-                wandb.log(metrics)
+                metric_logger.log(metrics)

                 if config.hv is None:
                     log(
@@ -512,7 +506,7 @@ def scheduler_fn(opt):
         if config.max_steps is not None and real_step >= config.max_steps:
             break
     log("Training completed.")
-    wandb.finish()
+    metric_logger.finish()


 if __name__ == "__main__":
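Note: the added import pulls WandbLogger and DummyLogger from open_diloco.utils, but that file is outside this diff, so their definitions are not shown. A minimal sketch of the interface the call sites above rely on (constructed with project and config, plus .log() and .finish()), assuming a thin wrapper around the wandb client; the actual open_diloco/utils.py may differ:

    # Hypothetical sketch -- not the actual open_diloco/utils.py implementation.
    from typing import Any

    import wandb


    class WandbLogger:
        """Forwards metrics to Weights & Biases."""

        def __init__(self, project: str, config: dict[str, Any]):
            wandb.init(project=project, config=config)

        def log(self, metrics: dict[str, Any]) -> None:
            wandb.log(metrics)

        def finish(self) -> None:
            wandb.finish()


    class DummyLogger:
        """No-op logger with the same interface, for runs without wandb."""

        def __init__(self, project: str, config: dict[str, Any]):
            pass

        def log(self, metrics: dict[str, Any]) -> None:
            pass

        def finish(self) -> None:
            pass

Because only rank 0 constructs the logger, this also removes the need for the old broadcast_object_list dance that shipped the wandb run id to the other ranks.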
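Since Config is parsed with pydantic_config's parse_argv, the no-op logger should be selectable from the command line via the new field, presumably along these lines (script name and exact flag syntax assumed, not shown in this diff):

    python train_fsdp.py --metric_logger_type dummy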