2828 ],
2929)
3030
31- _WORK_PER_CHUNK = 10
32-
3331
3432def _count_score (
3533 v : float , wa : float , pa : float , epsilon_1 : float , epsilon_2 : float
@@ -369,22 +367,37 @@ def rollout(
369367 return success_entries , state_sampler_cache_update
370368
371369
def rollout_worker(task_queue, result_queue, rollout_func):
    """Consume rollout tasks from `task_queue` until a `None` sentinel.

    Each task is a tuple of positional arguments for `rollout_func`; the
    result of every call is pushed onto `result_queue`.

    Args:
        task_queue: Joinable queue of argument tuples; `None` terminates.
        result_queue: Queue that receives each `rollout_func(*task)` result.
        rollout_func: Callable invoked once per task.
    """
    while True:
        task = task_queue.get()
        if task is None:
            task_queue.task_done()
            break

        try:
            result = rollout_func(*task)
            result_queue.put(result)
        finally:
            # Mark the task done even if rollout_func raises; otherwise a
            # crashed worker leaves an unfinished task behind and the
            # parent's task_queue.join() deadlocks forever.
            task_queue.task_done()
382+
372383def explore (
373384 object_logger : ObjectLogManager ,
374385 hparams : Any ,
375386) -> set [SuccessEntry ]:
376- """
377- Generate success entries by performing exploration in the environment.
387+ """Generate success entries by performing exploration in the environment.
378388
379389 Uses multiple processes to collect rollouts and update the state cache.
380390
381- This function runs a specified number of iterations, where in each iteration, it samples
382- start states and entries from the cache, generates random seeds for each process, and
383- collects rollouts in parallel using multiprocessing. The collected rollouts are then used
384- to update the cache and accumulate successful entries.
391+ This function generates a specified number of tasks to send to
392+ rollout workers. For each task, it samples start states and
393+ entries from the cache, generates random seeds for each process,
394+ and collects rollouts in parallel using multiprocessing. The
395+ collected rollouts are then used to update the cache and
396+ accumulate successful entries.
385397
386398 Returns:
387399 set[SuccessEntry]: A set of generated success entries.
400+
388401 """
389402 rng = np .random .default_rng (hparams .seed )
390403
@@ -410,40 +423,44 @@ def explore(
410423 state_sampler .update (state_sampler_cache_update )
411424
412425 success_entries : set [SuccessEntry ] = set ()
413- for iteration in range (hparams .go_explore_num_iterations ):
414- if iteration + 1 % 20 == 0 :
415- print (
416- f"[Iteration { iteration } ] Successes: { len (success_entries )} ({ sorted ([len (x .trajectory ) for x in success_entries ])} )"
417- )
418- x = next (sorted (success_entries , key = lambda e : len (e .trajectory )))
419- print (" Trajectory: " , x .trajectory )
420426
427+ def _get_tasks (total_task_count , new_task_count , current_best_trajectory_length ):
421428 start_entries = state_sampler .sample (
422- n = hparams .go_explore_num_samples_per_iteration
429+ n = new_task_count * hparams .go_explore_num_samples_per_worker_task
423430 )
424431 if hparams .debug :
425432 object_logger .log (
426- "start_entries-width-{hparams.env_width}.pkl" ,
427- OccurrenceLogEntry (batch_idx = iteration , object = start_entries ),
433+ f "start_entries-width-{ hparams .env_width } .pkl" ,
434+ OccurrenceLogEntry (batch_idx = total_task_count , object = start_entries ),
428435 )
429436
430437 seeds = rng .integers (low = 0 , high = 2 ** 31 , size = len (start_entries ))
431438 rngs = list (map (np .random .default_rng , seeds ))
432439
433- start_entries_chunked = chunked (start_entries , _WORK_PER_CHUNK )
434- rngs_chunked = chunked (rngs , _WORK_PER_CHUNK )
435-
436- _collect_rollouts = functools .partial (
437- rollout ,
438- rollout_params ,
439- state_sampler .current_best_trajectory_length ,
440+ start_entries_chunked = chunked (
441+ start_entries , hparams .go_explore_num_samples_per_worker_task
440442 )
443+ rngs_chunked = chunked (rngs , hparams .go_explore_num_samples_per_worker_task )
444+ return [
445+ (current_best_trajectory_length , start_entries_chunk , rngs_chunk )
446+ for start_entries_chunk , rngs_chunk in zip (
447+ start_entries_chunked , rngs_chunked
448+ )
449+ ]
450+
451+ def _process_results (count : int , final_flush : bool = False ):
452+ # Block on the first, then process up to count in total.
453+ try :
454+ for i in range (count ):
455+ if i == 0 or final_flush :
456+ rollout_success_entries , state_sampler_cache_update = (
457+ result_queue .get ()
458+ )
459+ else :
460+ rollout_success_entries , state_sampler_cache_update = (
461+ result_queue .get_nowait ()
462+ )
441463
442- with multiprocessing .Pool (processes = hparams .go_explore_num_processes ) as pool :
443- for rollout_success_entries , state_sampler_cache_update in pool .starmap (
444- _collect_rollouts ,
445- [* zip (start_entries_chunked , rngs_chunked )],
446- ):
447464 # Compile success entries from the current set of
448465 # rollouts to build out the return value for this
449466 # function.
@@ -452,6 +469,67 @@ def explore(
452469 # iteration of exploratory rollouts.
453470 state_sampler .update (state_sampler_cache_update )
454471
472+ result_queue .task_done ()
473+ except :
474+ pass
475+
476+ # Push initial tasks.
477+ task_target = hparams .go_explore_num_processes * 2
478+ task_queue = multiprocessing .JoinableQueue (maxsize = task_target )
479+ result_queue = multiprocessing .JoinableQueue ()
480+
481+ # Start worker processes.
482+ _collect_rollouts = functools .partial (rollout , rollout_params )
483+ workers = []
484+ for _ in range (hparams .go_explore_num_processes ):
485+ workers .append (
486+ multiprocessing .Process (
487+ target = rollout_worker ,
488+ args = (task_queue , result_queue , _collect_rollouts ),
489+ )
490+ )
491+ for worker in workers :
492+ worker .start ()
493+
494+ total_task_count = 0
495+ while True :
496+ if total_task_count + 1 % 20 == 0 :
497+ print (
498+ f"[Total Task Count { total_task_count } ] Successes: { len (success_entries )} ({ sorted ([len (x .trajectory ) for x in success_entries ])} )"
499+ )
500+ x = next (sorted (success_entries , key = lambda e : len (e .trajectory )))
501+ print (" Trajectory: " , x .trajectory )
502+
503+ if total_task_count == hparams .go_explore_num_tasks :
504+ # Clean up by posting sentinels.
505+ for _ in range (len (workers )):
506+ task_queue .put (None )
507+
508+ task_queue .join ()
509+
510+ # Process all completed tasks.
511+ _process_results (result_queue .qsize (), final_flush = True )
512+ result_queue .join ()
513+
514+ for worker in workers :
515+ worker .join ()
516+
517+ break
518+
519+ for task in _get_tasks (
520+ total_task_count = total_task_count ,
521+ new_task_count = task_target - task_queue .qsize (),
522+ current_best_trajectory_length = state_sampler .current_best_trajectory_length ,
523+ ):
524+ task_queue .put (task )
525+ total_task_count += 1
526+ if total_task_count == hparams .go_explore_num_tasks :
527+ break
528+
529+ # Process up to len(workers) results to balance batching and
530+ # not being stuck until all the work-in-flight is done.
531+ _process_results (len (workers ))
532+
455533 if hparams .debug :
456534 object_logger .log (
457535 f"state_cache-width-{ hparams .env_width } .pkl" ,
@@ -480,7 +558,7 @@ def explore(
480558
    # Persist every discovered success entry for offline analysis.
    for success_entry in success_entries:
        object_logger.log(
            f"success_entry-width-{hparams.env_width}.pkl", success_entry
        )
485563
486564 print (
0 commit comments