Skip to content

Commit 35cd120

Browse files
committed
test: check ckpt resuming hivemind
1 parent af06270 commit 35cd120

File tree

1 file changed

+75
-18
lines changed

1 file changed

+75
-18
lines changed

tests/test_training/test_train.py

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def config_hv() -> list[str]:
112112
]
113113

114114

115-
@pytest.mark.parametrize("num_diloco", [1, 2])
116-
def test_multi_gpu_hivemind(config_hv, num_diloco):
115+
@pytest.mark.parametrize("num_diloco", [2])
116+
def test_multi_gpu_hivemind(config_hv, num_diloco, tmp_path):
117117
dht = DHT(
118118
start=True,
119119
host_maddrs=[f"/ip4/0.0.0.0/tcp/{get_random_available_port()}"],
@@ -123,27 +123,84 @@ def test_multi_gpu_hivemind(config_hv, num_diloco):
123123

124124
results = []
125125

126+
ckpt_path = f"{tmp_path}/ckpt"
127+
128+
def get_base_cmd(i, initial_peers):
    """Build the torchrun command line shared by both test phases.

    Parameters
    ----------
    i : int
        Hivemind world rank of this DiLoCo worker.
    initial_peers : str
        Multiaddr of the DHT bootstrap peer.

    NOTE(review): relies on ``port``, ``config_hv`` and ``num_diloco``
    from the enclosing test scope (``port`` is rebound each loop
    iteration before this is called). Phase-specific flags (ckpt /
    project) are appended by the caller.
    """
    launcher = [
        "torchrun",
        "--nproc_per_node=1",
        "--rdzv-endpoint",
        f"localhost:{port}",
        "open_diloco/train_fsdp.py",
    ]
    hivemind_flags = [
        "--hv.initial_peers",
        initial_peers,
        "--hv.world_rank",
        str(i),
        "--hv.galaxy_size",
        str(num_diloco),
    ]
    return launcher + list(config_hv) + hivemind_flags
143+
126144
for i in range(num_diloco):
127145
port = get_random_available_port()
128-
result = subprocess.Popen(
129-
[
130-
"torchrun",
131-
f"--nproc_per_node={1}",
132-
"--rdzv-endpoint",
133-
f"localhost:{port}",
134-
"open_diloco/train_fsdp.py",
135-
*config_hv,
136-
"--hv.initial_peers",
137-
initial_peers,
138-
"--hv.world_rank",
139-
str(i),
140-
"--hv.galaxy_size",
141-
str(num_diloco),
142-
],
143-
)
146+
147+
cmd = get_base_cmd(i, initial_peers) + [
148+
"--ckpt.path",
149+
ckpt_path,
150+
"--ckpt.interval",
151+
"25",
152+
"--project",
153+
f"{tmp_path}/log{i}_part1.json",
154+
]
155+
156+
result = subprocess.Popen(cmd)
144157
results.append(result)
145158

146159
for result in results:
147160
result.wait()
148161
if result.returncode != 0:
149162
pytest.fail(f"Process {result} failed {result.stderr}")
163+
164+
# resume from ckpt
165+
166+
dht.shutdown()
167+
168+
del dht
169+
dht = DHT(
170+
start=True,
171+
host_maddrs=[f"/ip4/0.0.0.0/tcp/{get_random_available_port()}"],
172+
)
173+
initial_peers = str(dht.get_visible_maddrs()[0])
174+
175+
for i in range(num_diloco):
176+
port = get_random_available_port()
177+
178+
cmd = get_base_cmd(i, initial_peers) + [
179+
"--ckpt.resume",
180+
f"{ckpt_path}/{CKPT_PREFIX}_50",
181+
"--project",
182+
f"{tmp_path}/log{i}_part2.json",
183+
]
184+
185+
result = subprocess.Popen(cmd)
186+
results.append(result)
187+
188+
for result in results:
189+
result.wait()
190+
if result.returncode != 0:
191+
pytest.fail(f"Process {result} failed {result.stderr}")
192+
193+
for i in range(num_diloco):
194+
with open(f"{tmp_path}/log{i}_part1.json", "rb") as f:
195+
log1 = pickle.load(f)
196+
with open(f"{tmp_path}/log{i}_part2.json", "rb") as f:
197+
log2 = pickle.load(f)
198+
199+
log1 = {data["step"]: [data["Loss"], data["lr"]] for data in log1}
200+
log2 = {data["step"]: [data["Loss"], data["lr"]] for data in log2}
201+
202+
common_step = set(log1.keys()) & set(log2.keys())
203+
204+
for step in common_step:
205+
assert np.allclose(log1[step][0], log2[step][0], atol=1e-2), f"Loss at step {step} is different"
206+
assert log1[step][1] == log2[step][1], f"Lr at step {step} is different"

0 commit comments

Comments
 (0)