Skip to content

Commit 019992d

Browse files
authored
Async go explore (#244)
Make rollout workers run asynchronously relative to the cache updates in the main thread, to increase the throughput of rollouts.
1 parent 64c6568 commit 019992d

File tree

2 files changed

+114
-36
lines changed

2 files changed

+114
-36
lines changed

bridger/config/training.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@
109109
help_str = "The number of actions to take in an exploration rollout for go-explore."
110110
hparam_dict[key]["help"] = help_str
111111

112-
key = "go_explore_num_iterations"
112+
key = "go_explore_num_tasks"
113113
hparam_dict[key] = {"type": int, "default": 8}
114-
help_str = "The number of iterations of exploration rollouts for go-explore."
114+
help_str = "The number of exploration rollout tasks to assign for go-explore. This is effectively an iteration count."
115115
hparam_dict[key]["help"] = help_str
116116

117117
key = "go_explore_epsilon_1"
@@ -159,9 +159,9 @@
159159
help_str = "The y stride to downsample the state."
160160
hparam_dict[key]["help"] = help_str
161161

162-
key = "go_explore_num_samples_per_iteration"
163-
hparam_dict[key] = {"type": int, "default": 1200}
164-
help_str = "The number of processes to use for rollout."
162+
key = "go_explore_num_samples_per_worker_task"
163+
hparam_dict[key] = {"type": int, "default": 10}
164+
help_str = "The number of samples to send each worker as a task"
165165
hparam_dict[key]["help"] = help_str
166166

167167
key = "jitter"

bridger/go_explore_phase_1.py

Lines changed: 109 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828
],
2929
)
3030

31-
_WORK_PER_CHUNK = 10
32-
3331

3432
def _count_score(
3533
v: float, wa: float, pa: float, epsilon_1: float, epsilon_2: float
@@ -369,22 +367,37 @@ def rollout(
369367
return success_entries, state_sampler_cache_update
370368

371369

370+
def rollout_worker(task_queue, result_queue, rollout_func):
    """Consume rollout tasks from task_queue until a None sentinel arrives.

    Each task is a tuple of positional arguments for rollout_func; the
    return value of rollout_func(*task) is placed on result_queue.

    Fix: task_queue.task_done() is now guaranteed for every task retrieved,
    even when rollout_func raises. Previously an exception escaped before
    task_done() ran, leaving the task permanently "unfinished" and causing
    the main thread's task_queue.join() to deadlock.

    Args:
        task_queue: Joinable queue of argument tuples; None is the stop
            sentinel.
        result_queue: Queue that receives each rollout_func result.
        rollout_func: Callable invoked with the unpacked task tuple.

    Returns:
        None, once the sentinel has been consumed and acknowledged.
    """
    while True:
        task = task_queue.get()
        if task is None:
            # Acknowledge the sentinel so task_queue.join() can complete.
            task_queue.task_done()
            break

        try:
            result = rollout_func(*task)
            result_queue.put(result)
        finally:
            # Always balance the get() above, even on failure, so the
            # producer's task_queue.join() cannot hang on a crashed worker.
            task_queue.task_done()
    return None
381+
382+
372383
def explore(
373384
object_logger: ObjectLogManager,
374385
hparams: Any,
375386
) -> set[SuccessEntry]:
376-
"""
377-
Generate success entries by performing exploration in the environment.
387+
"""Generate success entries by performing exploration in the environment.
378388
379389
Uses multiple processes to collect rollouts and update the state cache.
380390
381-
This function runs a specified number of iterations, where in each iteration, it samples
382-
start states and entries from the cache, generates random seeds for each process, and
383-
collects rollouts in parallel using multiprocessing. The collected rollouts are then used
384-
to update the cache and accumulate successful entries.
391+
This function generates a specified number of tasks to send to
392+
rollout workers. For each task, it samples start states and
393+
entries from the cache, generates random seeds for each process,
394+
and collects rollouts in parallel using multiprocessing. The
395+
collected rollouts are then used to update the cache and
396+
accumulate successful entries.
385397
386398
Returns:
387399
set[SuccessEntry]: A set of generated success entries.
400+
388401
"""
389402
rng = np.random.default_rng(hparams.seed)
390403

@@ -410,40 +423,44 @@ def explore(
410423
state_sampler.update(state_sampler_cache_update)
411424

412425
success_entries: set[SuccessEntry] = set()
413-
for iteration in range(hparams.go_explore_num_iterations):
414-
if iteration + 1 % 20 == 0:
415-
print(
416-
f"[Iteration {iteration}] Successes: {len(success_entries)} ({sorted([len(x.trajectory) for x in success_entries ])})"
417-
)
418-
x = next(sorted(success_entries, key=lambda e: len(e.trajectory)))
419-
print(" Trajectory: ", x.trajectory)
420426

427+
def _get_tasks(total_task_count, new_task_count, current_best_trajectory_length):
    """Sample start states and package them into rollout worker tasks.

    Draws new_task_count * go_explore_num_samples_per_worker_task start
    entries from the state sampler, pairs every entry with its own seeded
    RNG, and groups them into per-task chunks.

    Args:
        total_task_count: Running task counter; used as the batch index
            when debug-logging the sampled start entries.
        new_task_count: Number of worker tasks to build.
        current_best_trajectory_length: Forwarded unchanged as the first
            element of every task tuple.

    Returns:
        A list of (current_best_trajectory_length, start_entries_chunk,
        rngs_chunk) tuples, one per task.
    """
    samples_per_task = hparams.go_explore_num_samples_per_worker_task
    start_entries = state_sampler.sample(n=new_task_count * samples_per_task)
    if hparams.debug:
        object_logger.log(
            f"start_entries-width-{hparams.env_width}.pkl",
            OccurrenceLogEntry(batch_idx=total_task_count, object=start_entries),
        )

    # One independently seeded RNG per start entry keeps rollouts
    # reproducible under the top-level seed.
    seeds = rng.integers(low=0, high=2**31, size=len(start_entries))
    rngs = [np.random.default_rng(seed) for seed in seeds]

    entry_chunks = chunked(start_entries, samples_per_task)
    rng_chunks = chunked(rngs, samples_per_task)
    return [
        (current_best_trajectory_length, entry_chunk, rng_chunk)
        for entry_chunk, rng_chunk in zip(entry_chunks, rng_chunks)
    ]
450+
451+
def _process_results(count: int, final_flush: bool = False):
452+
# Block on the first, then process up to count in total.
453+
try:
454+
for i in range(count):
455+
if i == 0 or final_flush:
456+
rollout_success_entries, state_sampler_cache_update = (
457+
result_queue.get()
458+
)
459+
else:
460+
rollout_success_entries, state_sampler_cache_update = (
461+
result_queue.get_nowait()
462+
)
441463

442-
with multiprocessing.Pool(processes=hparams.go_explore_num_processes) as pool:
443-
for rollout_success_entries, state_sampler_cache_update in pool.starmap(
444-
_collect_rollouts,
445-
[*zip(start_entries_chunked, rngs_chunked)],
446-
):
447464
# Compile success entries from the current set of
448465
# rollouts to build out the return value for this
449466
# function.
@@ -452,6 +469,67 @@ def explore(
452469
# iteration of exploratory rollouts.
453470
state_sampler.update(state_sampler_cache_update)
454471

472+
result_queue.task_done()
473+
except:
474+
pass
475+
476+
# Push initial tasks.  Keep roughly two tasks in flight per worker so
# workers never starve while the main thread processes results.
task_target = hparams.go_explore_num_processes * 2
task_queue = multiprocessing.JoinableQueue(maxsize=task_target)
result_queue = multiprocessing.JoinableQueue()

# Start worker processes.
_collect_rollouts = functools.partial(rollout, rollout_params)
workers = []
for _ in range(hparams.go_explore_num_processes):
    workers.append(
        multiprocessing.Process(
            target=rollout_worker,
            args=(task_queue, result_queue, _collect_rollouts),
        )
    )
for worker in workers:
    worker.start()

total_task_count = 0
while True:
    # Progress report every 20 tasks.  Fix: the original condition
    # `total_task_count + 1 % 20 == 0` parsed as
    # `total_task_count + (1 % 20) == 0` (% binds tighter than +) and
    # therefore never fired; parenthesize the sum.
    if (total_task_count + 1) % 20 == 0:
        print(
            f"[Total Task Count {total_task_count}] Successes: {len(success_entries)} ({sorted([len(x.trajectory) for x in success_entries ])})"
        )
        # Fix: next() on a list raises TypeError (a list is not an
        # iterator).  min() yields the same shortest-trajectory entry;
        # guard against an empty set to avoid ValueError.
        if success_entries:
            x = min(success_entries, key=lambda e: len(e.trajectory))
            print(" Trajectory: ", x.trajectory)

    if total_task_count == hparams.go_explore_num_tasks:
        # All tasks assigned.  Clean up by posting one sentinel per
        # worker, draining both queues, and reaping the workers.
        for _ in range(len(workers)):
            task_queue.put(None)

        task_queue.join()

        # Process all completed tasks.
        # NOTE(review): qsize() is approximate and unsupported on some
        # platforms (e.g. raises NotImplementedError on macOS) — confirm
        # target platforms before relying on it here.
        _process_results(result_queue.qsize(), final_flush=True)
        result_queue.join()

        for worker in workers:
            worker.join()

        break

    # Top the queue back up to task_target outstanding tasks.
    for task in _get_tasks(
        total_task_count=total_task_count,
        new_task_count=task_target - task_queue.qsize(),
        current_best_trajectory_length=state_sampler.current_best_trajectory_length,
    ):
        task_queue.put(task)
        total_task_count += 1
        if total_task_count == hparams.go_explore_num_tasks:
            break

    # Process up to len(workers) results to balance batching and
    # not being stuck until all the work-in-flight is done.
    _process_results(len(workers))
if hparams.debug:
456534
object_logger.log(
457535
f"state_cache-width-{hparams.env_width}.pkl",
@@ -480,7 +558,7 @@ def explore(
480558

481559
for success_entry in success_entries:
482560
object_logger.log(
483-
"success_entry-width-{hparams.env_width}.pkl", success_entry
561+
f"success_entry-width-{hparams.env_width}.pkl", success_entry
484562
)
485563

486564
print(

0 commit comments

Comments
 (0)