@@ -1,17 +1,10 @@
+import pickle
 import subprocess
+import numpy as np
 import pytest
 import socket
-import os
-from unittest import mock
 from hivemind.dht.dht import DHT
-
-
-@pytest.fixture(autouse=True)
-def set_env():
-    os.environ["WANDB_MODE"] = "disabled"
-
-    with mock.patch.dict(os.environ, {"WANDB_MODE": "disabled"}):
-        yield
+from open_diloco.ckpt_utils import CKPT_PREFIX
 
 
 def get_random_available_port():
@@ -41,25 +34,58 @@ def config() -> list[str]:
         "16",
         "--max_steps",
         "50",
+        "--metric_logger_type",
+        "dummy",
     ]
 
 
-@pytest.mark.parametrize("num_gpu", [1, 2])
-def test_multi_gpu(config, random_available_port, num_gpu):
-    result = subprocess.run(
-        [
-            "torchrun",
-            f"--nproc_per_node={num_gpu}",
-            "--rdzv-endpoint",
-            f"localhost:{random_available_port}",
-            "open_diloco/train_fsdp.py",
-            *config,
-        ],
-    )
+@pytest.mark.parametrize("num_gpu", [2])
+def test_multi_gpu_ckpt(config, random_available_port, num_gpu, tmp_path):
+    ckpt_path = f"{tmp_path}/ckpt"
+    log_file_1 = f"{tmp_path}/log1.json"
+    log_file_2 = f"{tmp_path}/log2.json"
+
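+    # First run: save a checkpoint every 10 steps; metrics go to log_file_1 via --project.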
+    run_1 = ["--ckpt.path", ckpt_path, "--ckpt.interval", "10", "--project", log_file_1]
+
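+    # Shared torchrun command; each run appends its own checkpoint/logging flags.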
+    cmd = [
+        "torchrun",
+        f"--nproc_per_node={num_gpu}",
+        "--rdzv-endpoint",
+        f"localhost:{random_available_port}",
+        "open_diloco/train_fsdp.py",
+        *config,
+    ]
+
+    result = subprocess.run(cmd + run_1)
 
     if result.returncode != 0:
         pytest.fail(f"Process {result} failed {result.stderr}")
 
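+    # Second run: resume from the step-20 checkpoint and log metrics to log_file_2.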
+    run_2 = ["--ckpt.path", ckpt_path, "--ckpt.resume", f"{ckpt_path}/{CKPT_PREFIX}_20", "--project", log_file_2]
+
+    results_resume = subprocess.run(cmd + run_2)
+
+    if results_resume.returncode != 0:
+        pytest.fail(f"Process {results_resume} failed {results_resume.stderr}")
+
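+    # Compare metrics at the steps common to both runs: loss within 1e-3, lr exactly.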
+    with open(log_file_1, "rb") as f:
+        log1 = pickle.load(f)
+    with open(log_file_2, "rb") as f:
+        log2 = pickle.load(f)
+
+    log1 = {data["step"]: [data["Loss"], data["lr"]] for data in log1}
+    log2 = {data["step"]: [data["Loss"], data["lr"]] for data in log2}
+
+    common_step = set(log1.keys()) & set(log2.keys())
+
+    for step in common_step:
+        assert np.allclose(log1[step][0], log2[step][0], atol=1e-3), f"Loss at step {step} is different"
+        assert log1[step][1] == log2[step][1], f"Lr at step {step} is different"
+
 
 
 @pytest.fixture
 def config_hv() -> list[str]:
@@ -76,6 +102,8 @@ def config_hv() -> list[str]:
         "16",
         "--max_steps",
         "100",
+        "--metric_logger_type",
+        "dummy",
     ]
 
     return config + [