1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import asyncio
15+ import logging
16+ import os
1517
1618import numpy as np
1719import pytest
3537
# Import ray lazily so this module can be collected even when ray is absent.
ray = lazy_import ("ray" )

# Module-level logger named after this test module (standard logging idiom).
logger = logging .getLogger (__name__ )

3943@pytest .fixture
4044async def speculative_cluster ():
@@ -224,14 +228,13 @@ async def test_auto_scale_in(ray_large_cluster):
224228 )
225229 while await autoscaler_ref .get_dynamic_worker_nums () > 2 :
226230 dynamic_workers = await autoscaler_ref .get_dynamic_workers ()
227- print (f"Waiting workers { dynamic_workers } to be released." )
231+ logger . info (f"Waiting %s workers to be released." , dynamic_workers )
228232 await asyncio .sleep (1 )
229233 await asyncio .sleep (1 )
230234 assert await autoscaler_ref .get_dynamic_worker_nums () == 2
231235
232236
233- @pytest .mark .skip ("Enable it when ray ownership bug is fixed" )
234- @pytest .mark .timeout (timeout = 200 )
237+ @pytest .mark .timeout (timeout = 150 )
235238@pytest .mark .parametrize ("ray_large_cluster" , [{"num_nodes" : 4 }], indirect = True )
236239@require_ray
237240@pytest .mark .asyncio
@@ -255,23 +258,62 @@ async def test_ownership_when_scale_in(ray_large_cluster):
255258 uid = AutoscalerActor .default_uid (),
256259 address = client ._cluster .supervisor_address ,
257260 )
258- await asyncio .gather (* [autoscaler_ref .request_worker () for _ in range (2 )])
259- df = md .DataFrame (mt .random .rand (100 , 4 , chunk_size = 2 ), columns = list ("abcd" ))
260- print (df .execute ())
261- assert await autoscaler_ref .get_dynamic_worker_nums () > 1
261+ num_chunks , chunk_size = 20 , 4
262+ df = md .DataFrame (
263+ mt .random .rand (num_chunks * chunk_size , 4 , chunk_size = chunk_size ),
264+ columns = list ("abcd" ),
265+ )
266+ latch_actor = ray .remote (CountDownLatch ).remote (1 )
267+ pid = os .getpid ()
268+
269+ def f (pdf , latch ):
270+ if os .getpid () != pid :
271+ # type inference will call this function too
272+ ray .get (latch .wait .remote ())
273+ return pdf
274+
275+ df = df .map_chunk (
276+ f ,
277+ args = (latch_actor ,),
278+ )
279+ info = df .execute (wait = False )
280+ while await autoscaler_ref .get_dynamic_worker_nums () <= 1 :
281+ logger .info ("Waiting workers to be created." )
282+ await asyncio .sleep (1 )
283+ await latch_actor .count_down .remote ()
284+ await info
285+ assert info .exception () is None
286+ assert info .progress () == 1
287+ logger .info ("df execute succeed." )
288+
262289 while await autoscaler_ref .get_dynamic_worker_nums () > 1 :
263290 dynamic_workers = await autoscaler_ref .get_dynamic_workers ()
264- print ( f "Waiting workers { dynamic_workers } to be released." )
291+ logger . info ( "Waiting workers %s to be released." , dynamic_workers )
265292 await asyncio .sleep (1 )
266293 # Test data on node of released worker can still be fetched
267- pd_df = df .to_pandas ()
268- groupby_sum_df = df .rechunk (40 ).groupby ("a" ).sum ()
269- print (groupby_sum_df .execute ())
294+ pd_df = df .fetch ()
295+ groupby_sum_df = df .rechunk (chunk_size * 2 ).groupby ("a" ).sum ()
296+ logger . info (groupby_sum_df .execute ())
270297 while await autoscaler_ref .get_dynamic_worker_nums () > 1 :
271298 dynamic_workers = await autoscaler_ref .get_dynamic_workers ()
272- print (f"Waiting workers { dynamic_workers } to be released." )
299+ logger . info (f"Waiting workers %s to be released." , dynamic_workers )
273300 await asyncio .sleep (1 )
274301 assert df .to_pandas ().to_dict () == pd_df .to_dict ()
275302 assert (
276303 groupby_sum_df .to_pandas ().to_dict () == pd_df .groupby ("a" ).sum ().to_dict ()
277304 )
305+
306+
class CountDownLatch:
    """Minimal asyncio-polling countdown latch.

    In the tests above it is wrapped in a Ray actor
    (``ray.remote(CountDownLatch).remote(1)``) so remote tasks can block in
    ``wait()`` until the driver calls ``count_down()``.
    """

    def __init__(self, cnt):
        # Number of count_down() calls required before wait() unblocks.
        self.cnt = cnt

    def count_down(self):
        """Decrement the latch count by one (may drive it below zero)."""
        self.cnt -= 1

    def get_count(self):
        """Return the current latch count."""
        return self.cnt

    async def wait(self):
        """Poll until the count reaches zero or below.

        Uses ``> 0`` rather than the original ``!= 0`` so that an extra
        count_down() call (pushing the count negative) cannot make waiters
        spin forever.
        """
        while self.cnt > 0:
            await asyncio.sleep(0.01)
0 commit comments