Skip to content

Commit 4f393f1

Browse files
committed
Add visualization code
1 parent f60baf3 commit 4f393f1

File tree

7 files changed

+445
-6
lines changed

7 files changed

+445
-6
lines changed

bridger/go_explore_phase_1.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -444,10 +444,10 @@ def explore(
444444
# iteration of exploratory rollouts.
445445
state_sampler.update(state_sampler_cache_update)
446446

447-
object_logger.log(
448-
f"state_cache-{hparams.env_width}.pkl",
449-
OccurrenceLogEntry(batch_idx=0, object=state_sampler),
450-
)
447+
object_logger.log(
448+
f"state_cache-{hparams.env_width}.pkl",
449+
OccurrenceLogEntry(batch_idx=iteration, object=state_sampler),
450+
)
451451

452452
return success_entries
453453

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from bridger.go_explore_phase_1 import CacheEntry
2+
3+
4+
class CacheEntryDatabase:
5+
"""
6+
Provides different views into a cache. The cache is a list of CacheEntry objects and can be sorted by trajectory length, steps_since_led_to_something_new, visit count, or sample count.
7+
"""
8+
9+
def __init__(self, cache_entries: list[CacheEntry]):
10+
self.cache_entries = cache_entries
11+
12+
def sort_by_key(self, sort_key: "SortKey"):
13+
"""
14+
Sorts the cache entries using the provided sort key.
15+
"""
16+
self.cache_entries = sorted(self.cache_entries, key=sort_key)
17+
18+
def get_top_n_by_sort_key(self, sort_key: "SortKey", n: int):
19+
"""
20+
Returns the top n cache entries.
21+
"""
22+
return self.sort_by_key(sort_key)[:n]
23+
24+
25+
class SortKey:
26+
def __init__(self, key: str):
27+
self.key = key
28+
29+
def __call__(self, cache_entry: CacheEntry):
30+
return getattr(cache_entry, self.key)
31+
32+
33+
class TrajectorySortKey(SortKey):
34+
def __init__(self):
35+
super().__init__("trajectory")
36+
37+
38+
class StepsSinceLedToSomethingNewSortKey(SortKey):
39+
def __init__(self):
40+
super().__init__("steps_since_led_to_something_new")
41+
42+
43+
class StepsSinceLedToSomethingNewResetCountSortKey(SortKey):
44+
def __init__(self):
45+
super().__init__("steps_since_led_to_something_new_reset_count")
46+
47+
48+
class VisitCountSortKey(SortKey):
49+
def __init__(self):
50+
super().__init__("visit_count")
51+
52+
53+
class SampleCountSortKey(SortKey):
54+
def __init__(self):
55+
super().__init__("sample_count")

bridger/logging_utils/object_log_readers.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,20 @@ def read_object_log(log_filepath: str):
3131
buffer = None
3232
while True:
3333
try:
34-
buffer = pickle.load(f)
34+
# Try to load with torch.load first to handle CUDA tensors
35+
try:
36+
buffer = torch.load(f, map_location=torch.device("cpu"))
37+
except:
38+
# If that fails, try regular pickle.load
39+
f.seek(0) # Reset file pointer
40+
buffer = pickle.load(f)
41+
# Move any tensors to CPU
42+
if isinstance(buffer, list):
43+
for i, item in enumerate(buffer):
44+
if isinstance(item, torch.Tensor):
45+
buffer[i] = item.cpu()
46+
elif isinstance(buffer, torch.Tensor):
47+
buffer = buffer.cpu()
3548

3649
for element in buffer:
3750
yield element

tools/web/object_log_cache.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
ACTION_INVERSION_DATABASE_KEY = "action_inversion_database_key"
2626
TRAINING_HISTORY_DATABASE_KEY = "training_history_database_key"
27+
CACHE_ENTRY_DATABASE_KEY = "cache_entry_database_key"
2728

2829
DatabaseType = (
2930
object_log_readers.TrainingHistoryDatabase
@@ -73,6 +74,13 @@ def _load_training_history_database_from_log(
7374
def _save_database(
7475
directory: str, experiment_name: str, database: DatabaseType
7576
) -> None:
77+
"""Saves a database to disk.
78+
79+
Args:
80+
directory: The directory to save the database in.
81+
experiment_name: The name of the experiment to save.
82+
database: The database to save.
83+
"""
7684
with open(os.path.join(directory, experiment_name), "wb") as f:
7785
pickle.dump(database, f)
7886

@@ -82,8 +90,17 @@ def _database_exists(directory: str, experiment_name: str) -> bool:
8290

8391

8492
def _load_database(directory: str, experiment_name: str) -> DatabaseType:
93+
"""Loads a database from disk.
94+
95+
Args:
96+
directory: The directory to load the database from.
97+
experiment_name: The name of the experiment to load.
98+
99+
Returns:
100+
The loaded database.
101+
"""
85102
with open(os.path.join(directory, experiment_name), "rb") as f:
86-
return pickle.load(f)
103+
return pickle.load(f, map_location=torch.device("cpu"))
87104

88105

89106
def _convert_log_to_saved_database_if_necessary(

tools/web/sibyl.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@
1111
from tools.web import object_log_cache
1212
from tools.web import plot_utils
1313
from collections import Counter
14+
from bridger.go_explore_phase_1 import CacheEntry
15+
from bridger.logging_utils.cache_entry_database import (
16+
CacheEntryDatabase,
17+
TrajectorySortKey,
18+
VisitCountSortKey,
19+
SampleCountSortKey,
20+
StepsSinceLedToSomethingNewSortKey,
21+
StepsSinceLedToSomethingNewResetCountSortKey,
22+
)
23+
from bridger.logging_utils.object_log_readers import read_object_log
1424

1525
app = flask.Flask(__name__)
1626

@@ -65,6 +75,38 @@ def _get_experiment_names() -> list[str]:
6575
]
6676

6777

78+
@app.route("/n_fewest_steps_since_led_to_something_new_go_explore", methods=["GET"])
79+
def n_fewest_steps_since_led_to_something_new_go_explore_plot_data():
80+
"""
81+
Provides plot data for the n states that have most recently seen a new cell.
82+
"""
83+
experiment_name = _get_string_or_default(_EXPERIMENT_NAME)
84+
n = _get_int_or_default("n", 10)
85+
86+
cache_entry_database = CacheEntryDatabase(
87+
list(read_object_log(os.path.join(_LOG_DIR, "go_explore", "start_entries.pkl")))
88+
)
89+
90+
return {
91+
"states": [
92+
cache_entry.state for cache_entry in cache_entry_database.cache_entries
93+
],
94+
"trajectory_length": cache_entry_database.get_top_n_by_sort_key(
95+
TrajectorySortKey, n
96+
),
97+
"steps_since_led_to_something_new": cache_entry_database.get_top_n_by_sort_key(
98+
StepsSinceLedToSomethingNewSortKey, n
99+
),
100+
"steps_since_led_to_something_new_reset_count": cache_entry_database.get_top_n_by_sort_key(
101+
StepsSinceLedToSomethingNewResetCountSortKey, n
102+
),
103+
"sample_count": cache_entry_database.get_top_n_by_sort_key(
104+
SampleCountSortKey, n
105+
),
106+
"visit_count": cache_entry_database.get_top_n_by_sort_key(VisitCountSortKey, n),
107+
}
108+
109+
68110
@app.route("/training_history_plot_data", methods=["GET"])
69111
def training_history_plot_data():
70112
"""Provides plot data on states and metrics based on filters.
@@ -374,6 +416,19 @@ def action_inversion():
374416
)
375417

376418

419+
@app.route("/go_explore")
420+
def go_explore():
421+
experiment_names = _get_experiment_names()
422+
selected_experiment_name = _get_string_or_default(
423+
name=_EXPERIMENT_NAME, default=experiment_names[0]
424+
)
425+
return flask.render_template(
426+
"go_explore.html",
427+
experiment_names=experiment_names,
428+
selected_experiment_name=selected_experiment_name,
429+
)
430+
431+
377432
if __name__ == "__main__":
378433
parser = argparse.ArgumentParser(description="Load Sibyl debugger.")
379434
parser.add_argument(

0 commit comments

Comments
 (0)