|
11 | 11 | from tools.web import object_log_cache |
12 | 12 | from tools.web import plot_utils |
13 | 13 | from collections import Counter |
| 14 | +from bridger.go_explore_phase_1 import CacheEntry |
| 15 | +from bridger.logging_utils.cache_entry_database import ( |
| 16 | + CacheEntryDatabase, |
| 17 | + TrajectorySortKey, |
| 18 | + VisitCountSortKey, |
| 19 | + SampleCountSortKey, |
| 20 | + StepsSinceLedToSomethingNewSortKey, |
| 21 | + StepsSinceLedToSomethingNewResetCountSortKey, |
| 22 | +) |
| 23 | +from bridger.logging_utils.object_log_readers import read_object_log |
14 | 24 |
|
15 | 25 | app = flask.Flask(__name__) |
16 | 26 |
|
@@ -65,6 +75,38 @@ def _get_experiment_names() -> list[str]: |
65 | 75 | ] |
66 | 76 |
|
67 | 77 |
|
| 78 | +@app.route("/n_fewest_steps_since_led_to_something_new_go_explore", methods=["GET"]) |
| 79 | +def n_fewest_steps_since_led_to_something_new_go_explore_plot_data(): |
| 80 | + """ |
| 81 | + Provides plot data for the n states that have most recently seen a new cell. |
| 82 | + """ |
| 83 | + experiment_name = _get_string_or_default(_EXPERIMENT_NAME) |
| 84 | + n = _get_int_or_default("n", 10) |
| 85 | + |
| 86 | + cache_entry_database = CacheEntryDatabase( |
| 87 | + list(read_object_log(os.path.join(_LOG_DIR, "go_explore", "start_entries.pkl"))) |
| 88 | + ) |
| 89 | + |
| 90 | + return { |
| 91 | + "states": [ |
| 92 | + cache_entry.state for cache_entry in cache_entry_database.cache_entries |
| 93 | + ], |
| 94 | + "trajectory_length": cache_entry_database.get_top_n_by_sort_key( |
| 95 | + TrajectorySortKey, n |
| 96 | + ), |
| 97 | + "steps_since_led_to_something_new": cache_entry_database.get_top_n_by_sort_key( |
| 98 | + StepsSinceLedToSomethingNewSortKey, n |
| 99 | + ), |
| 100 | + "steps_since_led_to_something_new_reset_count": cache_entry_database.get_top_n_by_sort_key( |
| 101 | + StepsSinceLedToSomethingNewResetCountSortKey, n |
| 102 | + ), |
| 103 | + "sample_count": cache_entry_database.get_top_n_by_sort_key( |
| 104 | + SampleCountSortKey, n |
| 105 | + ), |
| 106 | + "visit_count": cache_entry_database.get_top_n_by_sort_key(VisitCountSortKey, n), |
| 107 | + } |
| 108 | + |
| 109 | + |
68 | 110 | @app.route("/training_history_plot_data", methods=["GET"]) |
69 | 111 | def training_history_plot_data(): |
70 | 112 | """Provides plot data on states and metrics based on filters. |
@@ -374,6 +416,19 @@ def action_inversion(): |
374 | 416 | ) |
375 | 417 |
|
376 | 418 |
|
| 419 | +@app.route("/go_explore") |
| 420 | +def go_explore(): |
| 421 | + experiment_names = _get_experiment_names() |
| 422 | + selected_experiment_name = _get_string_or_default( |
| 423 | + name=_EXPERIMENT_NAME, default=experiment_names[0] |
| 424 | + ) |
| 425 | + return flask.render_template( |
| 426 | + "go_explore.html", |
| 427 | + experiment_names=experiment_names, |
| 428 | + selected_experiment_name=selected_experiment_name, |
| 429 | + ) |
| 430 | + |
| 431 | + |
377 | 432 | if __name__ == "__main__": |
378 | 433 | parser = argparse.ArgumentParser(description="Load Sibyl debugger.") |
379 | 434 | parser.add_argument( |
|
0 commit comments