@@ -371,11 +371,13 @@ def rollout_worker(task_queue, result_queue, rollout_func):
371371 while True :
372372 task = task_queue .get ()
373373 if task is None :
374+ task_queue .task_done ()
374375 break
375376
376377 result = rollout_func (* task )
377378 result_queue .put (result )
378-
379+ task_queue .task_done ()
380+ return None
379381
380382def explore (
381383 object_logger : ObjectLogManager ,
@@ -445,20 +447,19 @@ def _get_tasks(total_task_count, new_task_count, current_best_trajectory_length)
445447 )
446448 ]
447449
448- def _process_results (count : int ):
449- count_processed = 0
450+ def _process_results (count : int , final_flush : bool = False ):
451+ # Block on the first, then process up to count in total.
450452 try :
451- for _ in range (count ):
452- if count_processed == 0 :
453- # Block on the first, then process up to count in total.
453+ for i in range (count ):
454+ if i == 0 or final_flush :
454455 rollout_success_entries , state_sampler_cache_update = (
455456 result_queue .get ()
456457 )
457458 else :
458459 rollout_success_entries , state_sampler_cache_update = (
459460 result_queue .get_nowait ()
460461 )
461-
462+
462463 # Compile success entries from the current set of
463464 # rollouts to build out the return value for this
464465 # function.
@@ -467,16 +468,15 @@ def _process_results(count: int):
467468 # iteration of exploratory rollouts.
468469 state_sampler .update (state_sampler_cache_update )
469470
470- count_processed += 1
471+ result_queue . task_done ()
471472 except :
472473 pass
473474
474- return count_processed
475475
476476 # Push initial tasks.
477477 task_target = hparams .go_explore_num_processes * 2
478- task_queue = multiprocessing .Queue (maxsize = task_target )
479- result_queue = multiprocessing .Queue ()
478+ task_queue = multiprocessing .JoinableQueue (maxsize = task_target )
479+ result_queue = multiprocessing .JoinableQueue ()
480480
481481 # Start worker processes.
482482 _collect_rollouts = functools .partial (rollout , rollout_params )
@@ -492,7 +492,6 @@ def _process_results(count: int):
492492 worker .start ()
493493
494494 total_task_count = 0
495- tasks_in_flight = 0
496495 while True :
497496 if total_task_count + 1 % 20 == 0 :
498497 print (
@@ -501,34 +500,38 @@ def _process_results(count: int):
501500 x = next (sorted (success_entries , key = lambda e : len (e .trajectory )))
502501 print (" Trajectory: " , x .trajectory )
503502
504- process_result_count = len (workers )
505- if total_task_count < hparams .go_explore_num_tasks :
506- for task in _get_tasks (
507- total_task_count = total_task_count ,
508- new_task_count = task_target - tasks_in_flight ,
509- current_best_trajectory_length = state_sampler .current_best_trajectory_length ,
510- ):
511- task_queue .put (task )
512- tasks_in_flight += 1
513- total_task_count += 1
514- if total_task_count == hparams .go_explore_num_tasks :
515- break
516- else :
503+
504+ if total_task_count == hparams .go_explore_num_tasks :
517505 # Clean up by posting sentinels.
518506 for _ in range (len (workers )):
519507 task_queue .put (None )
508+
509+ task_queue .join ()
510+
511+ # Process all completed tasks.
512+ _process_results (result_queue .qsize (), final_flush = True )
513+ result_queue .join ()
514+
520515 for worker in workers :
521516 worker .join ()
522- # Process all completed tasks.
523- tasks_in_flight -= _process_results (tasks_in_flight )
524- break
525517
518+ break
519+
520+ for task in _get_tasks (
521+ total_task_count = total_task_count ,
522+ new_task_count = task_target - task_queue .qsize (),
523+ current_best_trajectory_length = state_sampler .current_best_trajectory_length ,
524+ ):
525+ task_queue .put (task )
526+ total_task_count += 1
527+ if total_task_count == hparams .go_explore_num_tasks :
528+ break
529+
526530 # Process up to len(workers) results to balance batching and
527531 # not being stuck until all the work-in-flight is done.
528- tasks_in_flight -= _process_results (len (workers ))
529-
530- assert tasks_in_flight == 0
532+ _process_results (len (workers ))
531533
534+
532535 if hparams .debug :
533536 object_logger .log (
534537 f"state_cache-width-{ hparams .env_width } .pkl" ,
0 commit comments