2222import uuid
2323import os
2424import time
25+ import random
2526
2627from datafusion_ray ._datafusion_ray_internal import (
2728 RayContext as RayContextInternal ,
@@ -35,6 +36,7 @@ class RayDataFrame:
3536 def __init__ (
3637 self ,
3738 ray_internal_df : RayDataFrameInternal ,
39+ num_partitions ,
3840 batch_size = 8192 ,
3941 isolate_parititions = False ,
4042 bucket : str | None = None ,
@@ -48,6 +50,7 @@ def __init__(
4850 self .isolate_partitions = isolate_parititions
4951 self .bucket = bucket
5052 self .num_exchangers = num_exchangers
53+ self .num_partitions = num_partitions
5154
5255 def stages (self ):
5356 # create our coordinator now, which we need to create stages
@@ -56,7 +59,12 @@ def stages(self):
5659
5760 self .coord = RayStageCoordinator .options (
5861 name = "RayQueryCoordinator:" + self .coordinator_id ,
59- ).remote (self .coordinator_id , len (self ._stages ), self .num_exchangers )
62+ ).remote (
63+ self .coordinator_id ,
64+ len (self ._stages ),
65+ self .num_exchangers ,
66+ self .num_partitions ,
67+ )
6068
6169 ray .get (self .coord .start_up .remote ())
6270 print ("ray coord started up" )
@@ -74,7 +82,7 @@ def collect(self) -> list[pa.RecordBatch]:
7482
7583 last_stage = max ([stage .stage_id for stage in self ._stages ])
7684
77- ref = self .coord .get_exchanger_addr .remote (last_stage )
85+ ref = self .coord .get_exchanger_addr .remote (last_stage , partition = 0 )
7886 self .create_ray_stages ()
7987 t3 = time .time ()
8088 print (f"creating ray stage actors took { t3 - t2 } s" )
@@ -86,7 +94,7 @@ def collect(self) -> list[pa.RecordBatch]:
8694 )
8795
8896 print ("calling df execute" )
89- reader = self .df .execute ({last_stage : addr })
97+ reader = self .df .execute ({( last_stage , 0 ) : addr })
9098 print ("called df execute, got reader" )
9199 self ._batches = list (reader )
92100 self .coord .all_done .remote ()
@@ -187,6 +195,7 @@ def sql(self, query: str) -> RayDataFrame:
187195 df = self .ctx .sql (query , coordinator_id )
188196 return RayDataFrame (
189197 df ,
198+ self .ctx .get_target_partitions (),
190199 self .batch_size ,
191200 self .isolate_partitions ,
192201 self .bucket ,
@@ -205,12 +214,17 @@ def set(self, option: str, value: str) -> None:
205214@ray .remote (num_cpus = 0 )
206215class RayStageCoordinator :
207216 def __init__ (
208- self , coordinator_id : str , num_stages : int , num_exchangers : int
217+ self ,
218+ coordinator_id : str ,
219+ num_stages : int ,
220+ num_exchangers : int ,
221+ num_partitions : int ,
209222 ) -> None :
210223 self .my_id = coordinator_id
211224 self .stages = {}
212225 self .num_stages = num_stages
213226 self .num_exchangers = num_exchangers
227+ self .num_partitions = num_partitions
214228 self .runtime_env = {}
215229
216230 def start_up (self ):
@@ -221,28 +235,27 @@ def start_up(self):
221235 RayExchanger .remote (f"Exchanger #{ i } " ) for i in range (self .num_exchangers )
222236 ]
223237
224- stages_per_exchanger = max (1 , self .num_stages // self .num_exchangers )
225- print ("Stages per exchanger: " , stages_per_exchanger )
226-
227238 refs = [exchange .start_up .remote () for exchange in self .xs ]
228239
229240 # ensure we've done the necessary initialization before continuing
230241 ray .wait (refs , num_returns = len (refs ))
231242 print ("all exchanges started up" )
232243
233- self . exchanges = {}
244+ # for each possible stage, and partition, assign it to an exchanger
234245 self .exchange_addrs = {}
235- for i in range (self .num_stages ):
236- exchanger_i = min (len (self .xs ) - 1 , i // stages_per_exchanger )
237- print ("exchanger_i = " , exchanger_i )
238- self .exchanges [i ] = self .xs [exchanger_i ]
239- self .exchange_addrs [i ] = ray .get (self .xs [exchanger_i ].addr .remote ())
246+ for stage_num in range (self .num_stages ):
247+ for partition_num in range (self .num_partitions ):
248+ exchanger_idx = random .choice (range (self .num_exchangers ))
249+ self .exchange_addrs [(stage_num , partition_num )] = ray .get (
250+ self .xs [exchanger_idx ].addr .remote ()
251+ )
252+ print (self .exchange_addrs )
240253
241254 # don't wait for these
242255 [exchange .serve .remote () for exchange in self .xs ]
243256
244- def get_exchanger_addr (self , stage_num : int ):
245- return self .exchange_addrs [stage_num ]
257+ def get_exchanger_addr (self , stage_num : int , partition : int ):
258+ return self .exchange_addrs [( stage_num , partition ) ]
246259
247260 def all_done (self ):
248261 print ("calling exchangers all done" )
@@ -269,16 +282,6 @@ def new_stage(
269282 ):
270283 stage_key = f"{ stage_id } -{ shadow_partition } "
271284 try :
272- if stage_key in self .stages :
273- print (f"already started stage { stage_key } " )
274- return self .stages [stage_key ]
275-
276- exchange_addr = self .exchange_addrs [stage_id ]
277-
278- input_exchange_addrs = {
279- input_stage_id : self .exchange_addrs [input_stage_id ]
280- for input_stage_id in input_stage_ids
281- }
282285
283286 print (f"creating new stage { stage_key } from bytes { len (plan_bytes )} " )
284287 stage = RayStage .options (
@@ -287,8 +290,7 @@ def new_stage(
287290 ).remote (
288291 stage_id ,
289292 plan_bytes ,
290- exchange_addr ,
291- input_exchange_addrs ,
293+ self .exchange_addrs ,
292294 fraction ,
293295 shadow_partition ,
294296 bucket ,
@@ -337,36 +339,31 @@ def __init__(
337339 self ,
338340 stage_id : str ,
339341 plan_bytes : bytes ,
340- exchanger_addr : str ,
341- input_exchange_addrs : dict [int , str ],
342+ exchanger_addrs : dict [tuple [int , int ], str ],
342343 fraction : float ,
343344 shadow_partition = None ,
344345 bucket : str | None = None ,
345346 ):
346347
347348 from datafusion_ray ._datafusion_ray_internal import PyStage
348349
350+ self .shadow_partition = shadow_partition
351+ shadow = (
352+ f", shadowing:{ self .shadow_partition } "
353+ if self .shadow_partition is not None
354+ else ""
355+ )
356+
349357 try :
350358 self .stage_id = stage_id
351359 self .pystage = PyStage (
352360 stage_id ,
353361 plan_bytes ,
354- exchanger_addr ,
355- input_exchange_addrs ,
362+ exchanger_addrs ,
356363 shadow_partition ,
357364 bucket ,
358365 fraction ,
359366 )
360- self .shadow_partition = shadow_partition
361- shadow = (
362- f", shadowing:{ self .shadow_partition } "
363- if self .shadow_partition is not None
364- else ""
365- )
366-
367- print (
368- f"RayStage[{ self .stage_id } { shadow } ] Sending to { exchanger_addr } , consuming from { input_exchange_addrs } "
369- )
370367 except Exception as e :
371368 print (
372369 f"RayStage[{ self .stage_id } { shadow } ] Unhandled Exception in init: { e } !"
0 commit comments