@@ -1,15 +1,9 @@
+import subprocess
 import pytest
+import socket
 import os
 from unittest import mock
-import socket
-from contextlib import contextmanager
-import multiprocessing
-import copy
-import torch
-import gc
-
 from hivemind.dht.dht import DHT
-from open_diloco.train_fsdp import train, Config, ddp_setup, destroy_process_group, HvConfig
 
 
 @pytest.fixture(autouse=True)
@@ -20,18 +14,6 @@ def set_env():
     yield
 
 
-@pytest.fixture(autouse=True)
-def memory_cleanup():
-    # credits to: https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117
-    try:
-        gc.collect()
-        torch.cuda.empty_cache()
-        yield
-    finally:
-        gc.collect()
-        torch.cuda.empty_cache()
-
-
 def get_random_available_port():
     # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -45,83 +27,99 @@ def random_available_port():
 
 
 @pytest.fixture
-def config() -> Config:
-    return Config(
-        path_model="tests/models/llama-2m-fresh",
-        fake_data=True,
-        torch_compile=False,
-        lr=1e-2,
-        per_device_train_batch_size=8,
-        total_batch_size=16,
-        max_steps=10,
+def config() -> list[str]:
+    return [
+        "--path_model",
+        "tests/models/llama-2m-fresh",
+        "--fake_data",
+        "--no-torch_compile",
+        "--lr",
+        "1e-2",
+        "--per_device_train_batch_size",
+        "8",
+        "--total_batch_size",
+        "16",
+        "--max_steps",
+        "50",
+    ]
+
+
+@pytest.mark.parametrize("num_gpu", [1, 2])
+def test_multi_gpu(config, random_available_port, num_gpu):
+    result = subprocess.run(
+        [
+            "torchrun",
+            f"--nproc_per_node={num_gpu}",
+            "--rdzv-endpoint",
+            f"localhost:{random_available_port}",
+            "open_diloco/train_fsdp.py",
+            *config,
+        ],
     )
 
-
-@contextmanager
-def ddp_environment(random_available_port, local_rank=0, world_size=1):
-    with mock.patch.dict(
-        os.environ,
-        {
-            "LOCAL_RANK": str(local_rank),
-            "WORLD_SIZE": str(world_size),
-            "RANK": str(local_rank),
-            "MASTER_ADDR": "localhost",
-            "MASTER_PORT": str(random_available_port),
-        },
-    ):
-        ddp_setup()
-        try:
-            yield
-        finally:
-            destroy_process_group()
+    if result.returncode != 0:
+        pytest.fail(f"Process {result} failed {result.stderr}")
 
 
 @pytest.fixture
-def simple_ddp_environment(random_available_port):
-    with ddp_environment(random_available_port, local_rank=0, world_size=1):
-        yield
-
-
-def test_train(config, simple_ddp_environment):
-    train(config)
-
-
-@pytest.mark.parametrize("world_size", [2])
-def test_multi_gpu(config, random_available_port, world_size):
-    def worker(local_rank):
-        with ddp_environment(random_available_port, local_rank=local_rank, world_size=world_size):
-            train(config)
-
-    processes = [multiprocessing.Process(target=worker, args=(rank,)) for rank in range(world_size)]
-    for p in processes:
-        p.start()
-    for p in processes:
-        p.join()
-
+def config_hv() -> list[str]:
+    config = [
+        "--path_model",
+        "tests/models/llama-2m-fresh",
+        "--fake_data",
+        "--no-torch_compile",
+        "--lr",
+        "1e-2",
+        "--per_device_train_batch_size",
+        "8",
+        "--total_batch_size",
+        "16",
+        "--max_steps",
+        "100",
+    ]
+
+    return config + [
+        "--hv.local_steps",
+        "25",
+        "--hv.skip_load_from_peers",
+        "--hv.fail_rank_drop",
+        "--hv.matchmaking_time",
+        "5",
+    ]
+
+
+@pytest.mark.parametrize("num_diloco", [1, 2])
+def test_multi_gpu_hivemind(config_hv, num_diloco):
+    dht = DHT(
+        start=True,
+        host_maddrs=[f"/ip4/0.0.0.0/tcp/{get_random_available_port()}"],
+    )
 
-@pytest.fixture
-def diloco_config(config: Config) -> Config:
-    hv_config = HvConfig(local_steps=5, skip_load_from_peers=True, world_rank=0, galaxy_size=1)
-    config.hv = hv_config
-
-    return config
-
-
-@pytest.mark.parametrize("galaxy_size", [2])
-def test_diloco_train(diloco_config: Config, galaxy_size):
-    dht = DHT(start=True)
-    diloco_config.hv.initial_peers = dht.get_visible_maddrs()
-    diloco_config.max_steps = 100
-
-    def worker(world_rank):
-        with ddp_environment(get_random_available_port(), local_rank=0, world_size=1):
-            config_copy: Config = copy.deepcopy(diloco_config)
-            config_copy.hv.galaxy_size = galaxy_size
-            config_copy.hv.world_rank = world_rank
-            train(config_copy)
-
-    processes = [multiprocessing.Process(target=worker, args=(rank,)) for rank in range(galaxy_size)]
-    for p in processes:
-        p.start()
-    for p in processes:
-        p.join()
+    initial_peers = str(dht.get_visible_maddrs()[0])
+
+    results = []
+
+    for i in range(num_diloco):
+        port = get_random_available_port()
+        result = subprocess.Popen(
+            [
+                "torchrun",
+                f"--nproc_per_node={1}",
+                "--rdzv-endpoint",
+                f"localhost:{port}",
+                "open_diloco/train_fsdp.py",
+                *config_hv,
+                "--hv.initial_peers",
+                initial_peers,
+                "--hv.world_rank",
+                str(i),
+                "--hv.galaxy_size",
+                str(num_diloco),
+            ],
+        )
+        results.append(result)
+
+    for result in results:
+        result.wait()
+        if result.returncode != 0:
+            pytest.fail(f"Process {result} failed {result.stderr}")
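
Note on the failure reporting above: neither subprocess.run nor subprocess.Popen is asked to capture output, so result.stderr is None and the pytest.fail message cannot show the child's error text. Below is a minimal sketch, not part of the committed diff, of how a torchrun launch could capture stderr for the failure report; the helper name run_torchrun and its signature are assumptions for illustration only.

import subprocess

import pytest


def run_torchrun(script: str, args: list[str], port: int, num_gpu: int = 1) -> None:
    # Hypothetical helper (not in the diff): launch a training script under
    # torchrun and surface the child's stderr in the pytest failure message.
    result = subprocess.run(
        [
            "torchrun",
            f"--nproc_per_node={num_gpu}",
            "--rdzv-endpoint",
            f"localhost:{port}",
            script,
            *args,
        ],
        capture_output=True,  # populate result.stdout / result.stderr
        text=True,
    )
    if result.returncode != 0:
        pytest.fail(f"torchrun exited with code {result.returncode}:\n{result.stderr}")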