Skip to content

Commit 3d28ed3

Browse files
committed
h
1 parent 71d822d commit 3d28ed3

File tree

1 file changed

+34
-31
lines changed

1 file changed

+34
-31
lines changed

bridger/go_explore_phase_1.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,13 @@ def rollout_worker(task_queue, result_queue, rollout_func):
371371
while True:
372372
task = task_queue.get()
373373
if task is None:
374+
task_queue.task_done()
374375
break
375376

376377
result = rollout_func(*task)
377378
result_queue.put(result)
378-
379+
task_queue.task_done()
380+
return None
379381

380382
def explore(
381383
object_logger: ObjectLogManager,
@@ -445,20 +447,19 @@ def _get_tasks(total_task_count, new_task_count, current_best_trajectory_length)
445447
)
446448
]
447449

448-
def _process_results(count: int):
449-
count_processed = 0
450+
def _process_results(count: int, final_flush: bool = False):
451+
# Block on the first, then process up to count in total.
450452
try:
451-
for _ in range(count):
452-
if count_processed == 0:
453-
# Block on the first, then process up to count in total.
453+
for i in range(count):
454+
if i == 0 or final_flush:
454455
rollout_success_entries, state_sampler_cache_update = (
455456
result_queue.get()
456457
)
457458
else:
458459
rollout_success_entries, state_sampler_cache_update = (
459460
result_queue.get_nowait()
460461
)
461-
462+
462463
# Compile success entries from the current set of
463464
# rollouts to build out the return value for this
464465
# function.
@@ -467,16 +468,15 @@ def _process_results(count: int):
467468
# iteration of exploratory rollouts.
468469
state_sampler.update(state_sampler_cache_update)
469470

470-
count_processed += 1
471+
result_queue.task_done()
471472
except:
472473
pass
473474

474-
return count_processed
475475

476476
# Push initial tasks.
477477
task_target = hparams.go_explore_num_processes * 2
478-
task_queue = multiprocessing.Queue(maxsize=task_target)
479-
result_queue = multiprocessing.Queue()
478+
task_queue = multiprocessing.JoinableQueue(maxsize=task_target)
479+
result_queue = multiprocessing.JoinableQueue()
480480

481481
# Start worker processes.
482482
_collect_rollouts = functools.partial(rollout, rollout_params)
@@ -492,7 +492,6 @@ def _process_results(count: int):
492492
worker.start()
493493

494494
total_task_count = 0
495-
tasks_in_flight = 0
496495
while True:
497496
if total_task_count + 1 % 20 == 0:
498497
print(
@@ -501,34 +500,38 @@ def _process_results(count: int):
501500
x = next(sorted(success_entries, key=lambda e: len(e.trajectory)))
502501
print(" Trajectory: ", x.trajectory)
503502

504-
process_result_count = len(workers)
505-
if total_task_count < hparams.go_explore_num_tasks:
506-
for task in _get_tasks(
507-
total_task_count=total_task_count,
508-
new_task_count=task_target - tasks_in_flight,
509-
current_best_trajectory_length=state_sampler.current_best_trajectory_length,
510-
):
511-
task_queue.put(task)
512-
tasks_in_flight += 1
513-
total_task_count += 1
514-
if total_task_count == hparams.go_explore_num_tasks:
515-
break
516-
else:
503+
504+
if total_task_count == hparams.go_explore_num_tasks:
517505
# Clean up by posting sentinels.
518506
for _ in range(len(workers)):
519507
task_queue.put(None)
508+
509+
task_queue.join()
510+
511+
# Process all completed tasks.
512+
_process_results(result_queue.qsize(), final_flush=True)
513+
result_queue.join()
514+
520515
for worker in workers:
521516
worker.join()
522-
# Process all completed tasks.
523-
tasks_in_flight -= _process_results(tasks_in_flight)
524-
break
525517

518+
break
519+
520+
for task in _get_tasks(
521+
total_task_count=total_task_count,
522+
new_task_count=task_target - task_queue.qsize(),
523+
current_best_trajectory_length=state_sampler.current_best_trajectory_length,
524+
):
525+
task_queue.put(task)
526+
total_task_count += 1
527+
if total_task_count == hparams.go_explore_num_tasks:
528+
break
529+
526530
# Process up to len(workers) results to balance batching and
527531
# not being stuck until all the work-in-flight is done.
528-
tasks_in_flight -= _process_results(len(workers))
529-
530-
assert tasks_in_flight == 0
532+
_process_results(len(workers))
531533

534+
532535
if hparams.debug:
533536
object_logger.log(
534537
f"state_cache-width-{hparams.env_width}.pkl",

0 commit comments

Comments
 (0)