@@ -112,8 +112,8 @@ def config_hv() -> list[str]:
     ]
 
 
-@pytest.mark.parametrize("num_diloco", [1, 2])
-def test_multi_gpu_hivemind(config_hv, num_diloco):
+@pytest.mark.parametrize("num_diloco", [2])
+def test_multi_gpu_hivemind(config_hv, num_diloco, tmp_path):
     dht = DHT(
         start=True,
         host_maddrs=[f"/ip4/0.0.0.0/tcp/{get_random_available_port()}"],
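For context (not part of the diff): `tmp_path` is pytest's built-in fixture that hands each test its own fresh temporary directory as a `pathlib.Path`, which is what lets the checkpoint and log files in the second hunk be written without any cleanup code. A minimal sketch of how it behaves:

import pathlib

def test_tmp_path_behaviour(tmp_path: pathlib.Path):
    # Each test invocation gets a fresh, empty directory.
    log_file = tmp_path / "log0_part1.json"
    log_file.write_text("[]")
    assert log_file.exists()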
@@ -123,27 +123,84 @@ def test_multi_gpu_hivemind(config_hv, num_diloco):
 
     results = []
 
+    ckpt_path = f"{tmp_path}/ckpt"
+
+    def get_base_cmd(i, initial_peers):
+        return [
+            "torchrun",
+            f"--nproc_per_node={1}",
+            "--rdzv-endpoint",
+            f"localhost:{port}",
+            "open_diloco/train_fsdp.py",
+            *config_hv,
+            "--hv.initial_peers",
+            initial_peers,
+            "--hv.world_rank",
+            str(i),
+            "--hv.galaxy_size",
+            str(num_diloco),
+        ]
+
     for i in range(num_diloco):
         port = get_random_available_port()
-        result = subprocess.Popen(
-            [
-                "torchrun",
-                f"--nproc_per_node={1}",
-                "--rdzv-endpoint",
-                f"localhost:{port}",
-                "open_diloco/train_fsdp.py",
-                *config_hv,
-                "--hv.initial_peers",
-                initial_peers,
-                "--hv.world_rank",
-                str(i),
-                "--hv.galaxy_size",
-                str(num_diloco),
-            ],
-        )
+
+        cmd = get_base_cmd(i, initial_peers) + [
+            "--ckpt.path",
+            ckpt_path,
+            "--ckpt.interval",
+            "25",
+            "--project",
+            f"{tmp_path}/log{i}_part1.json",
+        ]
+
+        result = subprocess.Popen(cmd)
         results.append(result)
 
     for result in results:
         result.wait()
         if result.returncode != 0:
             pytest.fail(f"Process {result} failed {result.stderr}")
+
+    # resume from ckpt
+
+    dht.shutdown()
+
+    del dht
+    dht = DHT(
+        start=True,
+        host_maddrs=[f"/ip4/0.0.0.0/tcp/{get_random_available_port()}"],
+    )
+    initial_peers = str(dht.get_visible_maddrs()[0])
+
+    for i in range(num_diloco):
+        port = get_random_available_port()
+
+        cmd = get_base_cmd(i, initial_peers) + [
+            "--ckpt.resume",
+            f"{ckpt_path}/{CKPT_PREFIX}_50",
+            "--project",
+            f"{tmp_path}/log{i}_part2.json",
+        ]
+
+        result = subprocess.Popen(cmd)
+        results.append(result)
+
+    for result in results:
+        result.wait()
+        if result.returncode != 0:
+            pytest.fail(f"Process {result} failed {result.stderr}")
+
+    for i in range(num_diloco):
+        with open(f"{tmp_path}/log{i}_part1.json", "rb") as f:
+            log1 = pickle.load(f)
+        with open(f"{tmp_path}/log{i}_part2.json", "rb") as f:
+            log2 = pickle.load(f)
+
+        log1 = {data["step"]: [data["Loss"], data["lr"]] for data in log1}
+        log2 = {data["step"]: [data["Loss"], data["lr"]] for data in log2}
+
+        common_step = set(log1.keys()) & set(log2.keys())
+
+        for step in common_step:
+            assert np.allclose(log1[step][0], log2[step][0], atol=1e-2), f"Loss at step {step} is different"
+            assert log1[step][1] == log2[step][1], f"Lr at step {step} is different"
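A note on the resume check: the first run checkpoints every 25 steps (`--ckpt.interval 25`) and the second run resumes from `{ckpt_path}/{CKPT_PREFIX}_50`, so every step from 50 onward is logged by both runs; those overlapping steps are what the final loop compares. Despite the `.json` suffix, the logs are opened in binary mode and read with `pickle.load`, so the training script presumably pickles them. For debugging a failed run by hand, here is a small standalone sketch of the same comparison; the pickled list-of-dicts format with "step", "Loss", and "lr" keys is inferred from the test above, and the file paths are placeholders:

import pickle

import numpy as np


def compare_logs(part1_path, part2_path):
    # Load each log and index it by training step.
    with open(part1_path, "rb") as f:
        log1 = {d["step"]: (d["Loss"], d["lr"]) for d in pickle.load(f)}
    with open(part2_path, "rb") as f:
        log2 = {d["step"]: (d["Loss"], d["lr"]) for d in pickle.load(f)}

    # Only steps at/after the resume point appear in both runs.
    for step in sorted(log1.keys() & log2.keys()):
        print(
            f"step {step}: "
            f"loss match={np.allclose(log1[step][0], log2[step][0], atol=1e-2)}, "
            f"lr match={log1[step][1] == log2[step][1]}"
        )

Called as, e.g., compare_logs("log0_part1.json", "log0_part2.json"), this prints a per-step match report instead of failing on the first divergence the way the test's asserts do.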