8
8
from torch .distributed import rpc
9
9
from multiprocessing import Pool , get_context
10
10
from ding .compatibility import torch_ge_1121
11
- from ditk import logging
12
11
from ding .utils .system_helper import find_free_port
13
12
14
13
mq = None
@@ -26,7 +25,6 @@ def torchrpc(rank):
26
25
mq = None
27
26
address = socket .gethostbyname (socket .gethostname ())
28
27
recv_tensor_list = [None , None , None , None ]
29
- logging .getLogger ().setLevel (logging .DEBUG )
30
28
name_list = ["A" , "B" , "C" , "D" ]
31
29
32
30
if rank == 0 :
@@ -85,7 +83,6 @@ def torchrpc_cuda(rank):
85
83
recv_tensor_list = [None , None , None , None ]
86
84
name_list = ["A" , "B" ]
87
85
address = socket .gethostbyname (socket .gethostname ())
88
- logging .getLogger ().setLevel (logging .DEBUG )
89
86
90
87
if rank == 0 :
91
88
attach_to = name_list [1 :]
@@ -95,7 +92,7 @@ def torchrpc_cuda(rank):
95
92
peer_rank = int (rank == 0 ) or 0
96
93
peer_name = name_list [peer_rank ]
97
94
device_map = DeviceMap (rank , [peer_name ], [rank ], [peer_rank ])
98
- logging . debug (device_map )
95
+ print (device_map )
99
96
100
97
mq = TORCHRPCMQ (
101
98
rpc_name = name_list [rank ],
@@ -132,7 +129,6 @@ def torchrpc_args_parser(rank):
132
129
global mq
133
130
global recv_tensor_list
134
131
from ding .framework .parallel import Parallel
135
- logging .getLogger ().setLevel (logging .DEBUG )
136
132
137
133
params = Parallel ._torchrpc_args_parser (
138
134
n_parallel_workers = 1 ,
@@ -143,30 +139,30 @@ def torchrpc_args_parser(rank):
143
139
local_cuda_devices = None ,
144
140
cuda_device_map = None
145
141
)[0 ]
146
-
147
- logging .debug (params )
142
+ print (params )
148
143
149
144
# 1. If attach_to is empty, init_rpc will not block.
150
145
mq = TORCHRPCMQ (** params )
151
146
mq .listen ()
152
147
assert mq ._running
153
148
mq .stop ()
154
149
assert not mq ._running
155
- logging . debug ("[Pass] 1. If attach_to is empty, init_rpc will not block." )
150
+ print ("[Pass] 1. If attach_to is empty, init_rpc will not block." )
156
151
157
152
# 2. n_parallel_workers != len(node_ids)
158
153
try :
159
154
Parallel ._torchrpc_args_parser (n_parallel_workers = 999 , attach_to = [], node_ids = [1 , 2 ])[0 ]
160
155
except RuntimeError as e :
161
- logging .debug ("[Pass] 2. n_parallel_workers != len(node_ids)." )
156
+ print ("[Pass] 2. n_parallel_workers != len(node_ids)." )
157
+ pass
162
158
else :
163
159
assert False
164
160
165
161
# 3. len(local_cuda_devices) != n_parallel_workers
166
162
try :
167
163
Parallel ._torchrpc_args_parser (n_parallel_workers = 8 , node_ids = [1 ], local_cuda_devices = [1 , 2 , 3 ])[0 ]
168
164
except RuntimeError as e :
169
- logging . debug ("[Pass] 3. len(local_cuda_devices) != n_parallel_workers." )
165
+ print ("[Pass] 3. len(local_cuda_devices) != n_parallel_workers." )
170
166
else :
171
167
assert False
172
168
@@ -175,7 +171,7 @@ def torchrpc_args_parser(rank):
175
171
try :
176
172
Parallel ._torchrpc_args_parser (n_parallel_workers = 999 , node_ids = [1 ], use_cuda = True )[0 ]
177
173
except RuntimeError as e :
178
- logging . debug ("[Pass] 4. n_parallel_workers > gpu_nums." )
174
+ print ("[Pass] 4. n_parallel_workers > gpu_nums." )
179
175
else :
180
176
assert False
181
177
@@ -186,8 +182,7 @@ def torchrpc_args_parser(rank):
186
182
assert params ['device_maps' ].peer_name_list == ["Node_0" , "Node_0" , "Node_1" ]
187
183
assert params ['device_maps' ].our_device_list == [0 , 1 , 1 ]
188
184
assert params ['device_maps' ].peer_device_list == [0 , 2 , 4 ]
189
- # logging.debug(params['device_maps'])
190
- logging .debug ("[Pass] 5. Set custom device map." )
185
+ print ("[Pass] 5. Set custom device map." )
191
186
192
187
# 6. Set n_parallel_workers > 1
193
188
params = Parallel ._torchrpc_args_parser (n_parallel_workers = 8 , node_ids = [1 ])
@@ -201,7 +196,7 @@ def torchrpc_args_parser(rank):
201
196
params = Parallel ._torchrpc_args_parser (n_parallel_workers = 2 , node_ids = [1 ], use_cuda = True )
202
197
assert params [0 ]['use_cuda' ]
203
198
assert len (params [0 ]['device_maps' ].peer_name_list ) == DEFAULT_DEVICE_MAP_NUMS - 1
204
- logging . debug ("[Pass] 6. Set n_parallel_workers > 1." )
199
+ print ("[Pass] 6. Set n_parallel_workers > 1." )
205
200
206
201
207
202
@pytest .mark .unittest
0 commit comments