Skip to content

Commit fb18a5e

Browse files
committed
refactor(testing): use torchrun during testing
1 parent 8b7b6a8 commit fb18a5e

File tree

1 file changed

+92
-94
lines changed

1 file changed

+92
-94
lines changed

tests/test_training/test_train.py

Lines changed: 92 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,9 @@
1+
import subprocess
12
import pytest
3+
import socket
24
import os
35
from unittest import mock
4-
import socket
5-
from contextlib import contextmanager
6-
import multiprocessing
7-
import copy
8-
import torch
9-
import gc
10-
116
from hivemind.dht.dht import DHT
12-
from open_diloco.train_fsdp import train, Config, ddp_setup, destroy_process_group, HvConfig
137

148

159
@pytest.fixture(autouse=True)
@@ -20,18 +14,6 @@ def set_env():
2014
yield
2115

2216

23-
@pytest.fixture(autouse=True)
24-
def memory_cleanup():
25-
# credits to : https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117
26-
try:
27-
gc.collect()
28-
torch.cuda.empty_cache()
29-
yield
30-
finally:
31-
gc.collect()
32-
torch.cuda.empty_cache()
33-
34-
3517
def get_random_available_port():
3618
# https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
3719
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -45,83 +27,99 @@ def random_available_port():
4527

4628

4729
@pytest.fixture
48-
def config() -> Config:
49-
return Config(
50-
path_model="tests/models/llama-2m-fresh",
51-
fake_data=True,
52-
torch_compile=False,
53-
lr=1e-2,
54-
per_device_train_batch_size=8,
55-
total_batch_size=16,
56-
max_steps=10,
30+
def config() -> list[str]:
31+
return [
32+
"--path_model",
33+
"tests/models/llama-2m-fresh",
34+
"--fake_data",
35+
"--no-torch_compile",
36+
"--lr",
37+
"1e-2",
38+
"--per_device_train_batch_size",
39+
"8",
40+
"--total_batch_size",
41+
"16",
42+
"--max_steps",
43+
"50",
44+
]
45+
46+
47+
@pytest.mark.parametrize("num_gpu", [1, 2])
48+
def test_multi_gpu(config, random_available_port, num_gpu):
49+
result = subprocess.run(
50+
[
51+
"torchrun",
52+
f"--nproc_per_node={num_gpu}",
53+
"--rdzv-endpoint",
54+
f"localhost:{random_available_port}",
55+
"open_diloco/train_fsdp.py",
56+
*config,
57+
],
5758
)
5859

59-
60-
@contextmanager
61-
def ddp_environment(random_available_port, local_rank=0, world_size=1):
62-
with mock.patch.dict(
63-
os.environ,
64-
{
65-
"LOCAL_RANK": str(local_rank),
66-
"WORLD_SIZE": str(world_size),
67-
"RANK": str(local_rank),
68-
"MASTER_ADDR": "localhost",
69-
"MASTER_PORT": str(random_available_port),
70-
},
71-
):
72-
ddp_setup()
73-
try:
74-
yield
75-
finally:
76-
destroy_process_group()
60+
if result.returncode != 0:
61+
pytest.fail(f"Process {result} failed {result.stderr}")
7762

7863

7964
@pytest.fixture
80-
def simple_ddp_environment(random_available_port):
81-
with ddp_environment(random_available_port, local_rank=0, world_size=1):
82-
yield
83-
84-
85-
def test_train(config, simple_ddp_environment):
86-
train(config)
87-
88-
89-
@pytest.mark.parametrize("world_size", [2])
90-
def test_multi_gpu(config, random_available_port, world_size):
91-
def worker(local_rank):
92-
with ddp_environment(random_available_port, local_rank=local_rank, world_size=world_size):
93-
train(config)
94-
95-
processes = [multiprocessing.Process(target=worker, args=(rank,)) for rank in range(world_size)]
96-
for p in processes:
97-
p.start()
98-
for p in processes:
99-
p.join()
100-
65+
def config_hv() -> list[str]:
66+
config = [
67+
"--path_model",
68+
"tests/models/llama-2m-fresh",
69+
"--fake_data",
70+
"--no-torch_compile",
71+
"--lr",
72+
"1e-2",
73+
"--per_device_train_batch_size",
74+
"8",
75+
"--total_batch_size",
76+
"16",
77+
"--max_steps",
78+
"100",
79+
]
80+
81+
return config + [
82+
"--hv.local_steps",
83+
"25",
84+
"--hv.skip_load_from_peers",
85+
"--hv.fail_rank_drop",
86+
"--hv.matchmaking_time",
87+
"5",
88+
]
89+
90+
91+
@pytest.mark.parametrize("num_diloco", [1, 2])
92+
def test_multi_gpu_hivemind(config_hv, num_diloco):
93+
dht = DHT(
94+
start=True,
95+
host_maddrs=[f"/ip4/0.0.0.0/tcp/{get_random_available_port()}"],
96+
)
10197

102-
@pytest.fixture
103-
def diloco_config(config: Config) -> Config:
104-
hv_config = HvConfig(local_steps=5, skip_load_from_peers=True, world_rank=0, galaxy_size=1)
105-
config.hv = hv_config
106-
107-
return config
108-
109-
110-
@pytest.mark.parametrize("galaxy_size", [2])
111-
def test_diloco_train(diloco_config: Config, galaxy_size):
112-
dht = DHT(start=True)
113-
diloco_config.hv.initial_peers = dht.get_visible_maddrs()
114-
diloco_config.max_steps = 100
115-
116-
def worker(world_rank):
117-
with ddp_environment(get_random_available_port(), local_rank=0, world_size=1):
118-
config_copy: Config = copy.deepcopy(diloco_config)
119-
config_copy.hv.galaxy_size = galaxy_size
120-
config_copy.hv.world_rank = world_rank
121-
train(config_copy)
122-
123-
processes = [multiprocessing.Process(target=worker, args=(rank,)) for rank in range(galaxy_size)]
124-
for p in processes:
125-
p.start()
126-
for p in processes:
127-
p.join()
98+
initial_peers = str(dht.get_visible_maddrs()[0])
99+
100+
results = []
101+
102+
for i in range(num_diloco):
103+
port = get_random_available_port()
104+
result = subprocess.Popen(
105+
[
106+
"torchrun",
107+
f"--nproc_per_node={1}",
108+
"--rdzv-endpoint",
109+
f"localhost:{port}",
110+
"open_diloco/train_fsdp.py",
111+
*config_hv,
112+
"--hv.initial_peers",
113+
initial_peers,
114+
"--hv.world_rank",
115+
str(i),
116+
"--hv.galaxy_size",
117+
str(num_diloco),
118+
],
119+
)
120+
results.append(result)
121+
122+
for result in results:
123+
result.wait()
124+
if result.returncode != 0:
125+
pytest.fail(f"Process {result} failed {result.stderr}")

0 commit comments

Comments
 (0)