
Commit af06270

test: check ckpt resuming
Parent: 96a0e82

3 files changed (+51, −27 lines)


open_diloco/train_fsdp.py

Lines changed: 2 additions & 1 deletion
@@ -506,7 +506,8 @@ def scheduler_fn(opt):
         if config.max_steps is not None and real_step >= config.max_steps:
             break
     log("Training completed.")
-    metric_logger.finish()
+    if rank == 0:
+        metric_logger.finish()


 if __name__ == "__main__":
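The change guards metric_logger.finish() behind rank == 0, so only one worker finalises the metrics file in a multi-process run. A minimal sketch of that rank-guard pattern, assuming the RANK environment variable that torchrun sets per worker; the logger class is an illustrative stand-in, not the repo's API:

import os


class StubLogger:
    """Illustrative stand-in for the training script's metric logger."""

    def finish(self) -> None:
        print("flushing metrics to disk")


def shutdown(logger: StubLogger) -> None:
    # torchrun exports RANK for each worker process; only rank 0 flushes,
    # so concurrent workers do not all write the same output file.
    rank = int(os.environ.get("RANK", "0"))
    if rank == 0:
        logger.finish()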

open_diloco/utils.py

Lines changed: 3 additions & 4 deletions
@@ -1,6 +1,6 @@
 import hashlib
 from functools import partial
-import json
+import pickle
 from typing import Any, Generator, Protocol

 import torch
@@ -210,6 +210,5 @@ def log(self, metrics: dict[str, Any]):
         self.data.append(metrics)

     def finish(self):
-        with open(self.project, "a") as f:
-            for d in self.data:
-                f.write(json.dumps(d) + "\n")
+        with open(self.project, "wb") as f:
+            pickle.dump(self.data, f)
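With this change the dummy logger no longer appends JSON lines; finish() serialises the whole metrics list as a single pickle payload, which is what the updated test reads back. A minimal sketch of the round trip under that assumption (the literal file name and metric values are illustrative):

import pickle

# Write side: same shape as DummyLogger.finish() above, one pickle dump of
# the accumulated list of metric dicts.
metrics = [{"step": 10, "Loss": 2.31, "lr": 3e-4}]
with open("log1.json", "wb") as f:  # the test keeps a .json name even though the payload is pickle
    pickle.dump(metrics, f)

# Read side: what the test does after each training run.
with open("log1.json", "rb") as f:
    loaded = pickle.load(f)
assert loaded == metrics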

tests/test_training/test_train.py

Lines changed: 46 additions & 22 deletions
@@ -1,17 +1,10 @@
+import pickle
 import subprocess
+import numpy as np
 import pytest
 import socket
-import os
-from unittest import mock
 from hivemind.dht.dht import DHT
-
-
-@pytest.fixture(autouse=True)
-def set_env():
-    os.environ["WANDB_MODE"] = "disabled"
-
-    with mock.patch.dict(os.environ, {"WANDB_MODE": "disabled"}):
-        yield
+from open_diloco.ckpt_utils import CKPT_PREFIX


 def get_random_available_port():
@@ -41,25 +34,54 @@ def config() -> list[str]:
         "16",
         "--max_steps",
         "50",
+        "--metric_logger_type",
+        "dummy",
     ]


-@pytest.mark.parametrize("num_gpu", [1, 2])
-def test_multi_gpu(config, random_available_port, num_gpu):
-    result = subprocess.run(
-        [
-            "torchrun",
-            f"--nproc_per_node={num_gpu}",
-            "--rdzv-endpoint",
-            f"localhost:{random_available_port}",
-            "open_diloco/train_fsdp.py",
-            *config,
-        ],
-    )
+@pytest.mark.parametrize("num_gpu", [2])
+def test_multi_gpu_ckpt(config, random_available_port, num_gpu, tmp_path):
+    ckpt_path = f"{tmp_path}/ckpt"
+    log_file_1 = f"{tmp_path}/log1.json"
+    log_file_2 = f"{tmp_path}/log2.json"
+
+    run_1 = ["--ckpt.path", ckpt_path, "--ckpt.interval", "10", "--project", log_file_1]
+
+    cmd = [
+        "torchrun",
+        f"--nproc_per_node={num_gpu}",
+        "--rdzv-endpoint",
+        f"localhost:{random_available_port}",
+        "open_diloco/train_fsdp.py",
+        *config,
+    ]
+
+    result = subprocess.run(cmd + run_1)

     if result.returncode != 0:
         pytest.fail(f"Process {result} failed {result.stderr}")

+    run_2 = ["--ckpt.path", ckpt_path, "--ckpt.resume", f"{ckpt_path}/{CKPT_PREFIX}_20", "--project", log_file_2]
+
+    results_resume = subprocess.run(cmd + run_2)
+
+    if results_resume.returncode != 0:
+        pytest.fail(f"Process {result} failed {result.stderr}")
+
+    with open(log_file_1, "rb") as f:
+        log1 = pickle.load(f)
+    with open(log_file_2, "rb") as f:
+        log2 = pickle.load(f)
+
+    log1 = {data["step"]: [data["Loss"], data["lr"]] for data in log1}
+    log2 = {data["step"]: [data["Loss"], data["lr"]] for data in log2}
+
+    common_step = set(log1.keys()) & set(log2.keys())
+
+    for step in common_step:
+        assert np.allclose(log1[step][0], log2[step][0], atol=1e-3), f"Loss at step {step} is different"
+        assert log1[step][1] == log2[step][1], f"Lr at step {step} is different"
+

 @pytest.fixture
 def config_hv() -> list[str]:
@@ -76,6 +98,8 @@ def config_hv() -> list[str]:
         "16",
         "--max_steps",
         "100",
+        "--metric_logger_type",
+        "dummy",
     ]

     return config + [
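The new test boils down to: train once for 50 steps while checkpointing every 10, train again from the step-20 checkpoint into a second log file, and require that loss and learning rate agree on every step both runs recorded. A standalone sketch of that comparison, assuming the pickle log format above (the helper name is ours, not part of the repo):

import numpy as np


def check_resume_consistency(log1: list[dict], log2: list[dict], atol: float = 1e-3) -> None:
    # Index each run by step, keeping (loss, lr) pairs as the test does.
    run1 = {d["step"]: (d["Loss"], d["lr"]) for d in log1}
    run2 = {d["step"]: (d["Loss"], d["lr"]) for d in log2}

    # Only steps present in both runs are comparable: in this test the
    # resumed run covers roughly steps 20-50, the original run 0-50.
    for step in sorted(run1.keys() & run2.keys()):
        assert np.allclose(run1[step][0], run2[step][0], atol=atol), f"Loss differs at step {step}"
        assert run1[step][1] == run2[step][1], f"lr differs at step {step}"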
