@@ -211,8 +211,19 @@ def __call__(self):
211211 @staticmethod
212212 def _get_data_chunks_workers (ctx , data ):
213213 # data_chunk.inputs is concat, and concat's input is the co-allocated chunks
214- metas = ctx .get_chunks_meta ([c .key for c in data .chunks ], fields = ["bands" ])
215- return [m ["bands" ][0 ][0 ] for m in metas ]
214+ metas = ctx .get_chunks_meta (
215+ [c .key for c in data .chunks ], fields = ["ip" , "bands" ]
216+ )
217+
218+ ips = []
219+ ip_to_worker = {}
220+ for m in metas :
221+ ip = m ["ip" ]
222+ assert ip , "There is meta {meta} who doesn't contain ip."
223+ ips .append (ip )
224+ bands = m ["bands" ]
225+ ip_to_worker [ip ] = bands [0 ][0 ] if bands else None
226+ return ips , ip_to_worker
216227
217228 @staticmethod
218229 def _concat_chunks_by_worker (chunks , chunk_workers ):
@@ -230,23 +241,24 @@ def tile(cls, op: "LGBMTrain"):
230241 data = op .data
231242 worker_to_args = defaultdict (dict )
232243
233- workers = cls ._get_data_chunks_workers (ctx , data )
244+ # Note: Mars worker is band address, and LGBMTrain worker is machine ip.
245+ ips , ip_to_worker = cls ._get_data_chunks_workers (ctx , data )
234246
235247 for arg in ["_data" , "_label" , "_sample_weight" , "_init_score" ]:
236248 if getattr (op , arg ) is not None :
237249 for worker , chunk in cls ._concat_chunks_by_worker (
238- getattr (op , arg ).chunks , workers
250+ getattr (op , arg ).chunks , ips
239251 ).items ():
240252 worker_to_args [worker ][arg ] = chunk
241253
242254 if op .eval_datas :
243255 eval_workers_list = [
244- cls ._get_data_chunks_workers (ctx , d ) for d in op .eval_datas
256+ cls ._get_data_chunks_workers (ctx , d )[ 0 ] for d in op .eval_datas
245257 ]
246258 extra_workers = reduce (
247259 operator .or_ , (set (w ) for w in eval_workers_list )
248- ) - set (workers )
249- worker_remap = dict (zip (extra_workers , itertools .cycle (workers )))
260+ ) - set (ips )
261+ worker_remap = dict (zip (extra_workers , itertools .cycle (ips )))
250262 if worker_remap :
251263 eval_workers_list = [
252264 [worker_remap .get (w , w ) for w in wl ] for wl in eval_workers_list
@@ -270,10 +282,11 @@ def tile(cls, op: "LGBMTrain"):
270282 worker_to_args [worker ][arg ].append (chunk )
271283
272284 out_chunks = []
273- workers = list (set (workers ))
274- for worker_id , worker in enumerate (workers ):
285+ ips = list (set (ips ))
286+ workers = list (ip_to_worker .values ())
287+ for worker_id , worker in enumerate (ips ):
275288 chunk_op = op .copy ().reset_key ()
276- chunk_op .expect_worker = worker
289+ chunk_op .expect_worker = ip_to_worker [ worker ]
277290
278291 input_chunks = []
279292 concat_args = worker_to_args .get (worker , {})
@@ -301,7 +314,7 @@ def tile(cls, op: "LGBMTrain"):
301314 ).chunks [0 ]
302315 input_chunks .append (worker_ports_chunk )
303316
304- chunk_op ._workers = workers
317+ chunk_op ._workers = ips
305318 chunk_op ._worker_ports = worker_ports_chunk
306319 chunk_op ._worker_id = worker_id
307320
@@ -357,9 +370,8 @@ def execute(cls, ctx, op: "LGBMTrain"):
357370 # if model is trained, remove unsupported parameters
358371 params .pop ("out_dtype_" , None )
359372 worker_ports = ctx [op .worker_ports .key ]
360- worker_ips = [worker .split (":" , 1 )[0 ] for worker in op .workers ]
361373 worker_endpoints = [
362- f"{ worker } :{ port } " for worker , port in zip (worker_ips , worker_ports )
374+ f"{ worker } :{ port } " for worker , port in zip (op . workers , worker_ports )
363375 ]
364376
365377 params ["machines" ] = "," .join (worker_endpoints )
0 commit comments