Skip to content

Commit efce967

Browse files
committed
Still debugging, but turned off background loading and continuing to load cache entries
1 parent 4f393f1 commit efce967

File tree

3 files changed

+70
-37
lines changed

3 files changed

+70
-37
lines changed

bridger/logging_utils/cache_entry_database.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from bridger.go_explore_phase_1 import CacheEntry
2+
import time
23

34

45
class CacheEntryDatabase:
@@ -7,27 +8,68 @@ class CacheEntryDatabase:
78
"""
89

910
def __init__(self, cache_entries: list[CacheEntry]):
11+
print(f"Initializing CacheEntryDatabase with {len(cache_entries)} entries")
12+
start_time = time.time()
13+
14+
# Debug info about the entries
15+
if cache_entries:
16+
print(f"First entry type: {type(cache_entries[0])}")
17+
print(f"First entry attributes: {dir(cache_entries[0])}")
18+
try:
19+
print(
20+
f"First entry state shape: {cache_entries[0].state_representative.shape}"
21+
)
22+
except Exception as e:
23+
print(f"Error getting state shape: {e}")
24+
1025
self.cache_entries = cache_entries
26+
end_time = time.time()
27+
print(
28+
f"CacheEntryDatabase initialization took {end_time - start_time:.2f} seconds"
29+
)
1130

1231
def sort_by_key(self, sort_key: "SortKey"):
1332
"""
1433
Sorts the cache entries using the provided sort key.
1534
"""
16-
self.cache_entries = sorted(self.cache_entries, key=sort_key)
35+
print(f"Sorting cache entries by {sort_key.key}")
36+
start_time = time.time()
37+
try:
38+
self.cache_entries = sorted(self.cache_entries, key=sort_key)
39+
end_time = time.time()
40+
print(f"Sorting completed in {end_time - start_time:.2f} seconds")
41+
except Exception as e:
42+
print(f"Error during sorting: {e}")
43+
raise
1744

1845
def get_top_n_by_sort_key(self, sort_key: "SortKey", n: int):
1946
"""
2047
Returns the top n cache entries.
2148
"""
22-
return self.sort_by_key(sort_key)[:n]
49+
print(f"Getting top {n} entries by {sort_key.key}")
50+
start_time = time.time()
51+
try:
52+
sorted_entries = self.sort_by_key(sort_key)
53+
result = sorted_entries[:n]
54+
end_time = time.time()
55+
print(f"Retrieved top {n} entries in {end_time - start_time:.2f} seconds")
56+
return result
57+
except Exception as e:
58+
print(f"Error getting top entries: {e}")
59+
raise
2360

2461

2562
class SortKey:
2663
def __init__(self, key: str):
2764
self.key = key
2865

2966
def __call__(self, cache_entry: CacheEntry):
30-
return getattr(cache_entry, self.key)
67+
try:
68+
value = getattr(cache_entry, self.key)
69+
return value
70+
except Exception as e:
71+
print(f"Error accessing {self.key} on cache entry: {e}")
72+
raise
3173

3274

3375
class TrajectorySortKey(SortKey):

bridger/logging_utils/object_log_readers.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,10 @@
2828

2929
def read_object_log(log_filepath: str):
3030
with open(log_filepath, "rb") as f:
31-
buffer = None
3231
while True:
3332
try:
34-
# Try to load with torch.load first to handle CUDA tensors
35-
try:
36-
buffer = torch.load(f, map_location=torch.device("cpu"))
37-
except:
38-
# If that fails, try regular pickle.load
39-
f.seek(0) # Reset file pointer
40-
buffer = pickle.load(f)
41-
# Move any tensors to CPU
42-
if isinstance(buffer, list):
43-
for i, item in enumerate(buffer):
44-
if isinstance(item, torch.Tensor):
45-
buffer[i] = item.cpu()
46-
elif isinstance(buffer, torch.Tensor):
47-
buffer = buffer.cpu()
33+
print(f"This is the type: {type(f)}")
34+
buffer = pickle.load(f)
4835

4936
for element in buffer:
5037
yield element
@@ -311,6 +298,7 @@ def __init__(self, dirname: str):
311298
for entry in _read_object_log(
312299
os.path.dirname(dirname), log_entry.STATE_NORMALIZED_LOG_ENTRY
313300
):
301+
print(f"")
314302
self._states[entry.id] = entry.object
315303

316304
# Store visited states sorted by visit count.

tools/web/sibyl.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
StepsSinceLedToSomethingNewResetCountSortKey,
2222
)
2323
from bridger.logging_utils.object_log_readers import read_object_log
24+
from bridger.go_explore_phase_1 import StateSampler
2425

2526
app = flask.Flask(__name__)
2627

@@ -80,30 +81,32 @@ def n_fewest_steps_since_led_to_something_new_go_explore_plot_data():
8081
"""
8182
Provides plot data for the n states that have most recently seen a new cell.
8283
"""
83-
experiment_name = _get_string_or_default(_EXPERIMENT_NAME)
8484
n = _get_int_or_default("n", 10)
85-
86-
cache_entry_database = CacheEntryDatabase(
87-
list(read_object_log(os.path.join(_LOG_DIR, "go_explore", "start_entries.pkl")))
85+
_CACHE_ENTRY_DATABASE = list(
86+
read_object_log(os.path.join(_LOG_DIR, "go_explore", "state_cache-6.pkl"))
8887
)
88+
print(_CACHE_ENTRY_DATABASE)
8989

9090
return {
9191
"states": [
92-
cache_entry.state for cache_entry in cache_entry_database.cache_entries
92+
cache_entry.object.state
93+
for cache_entry in _CACHE_ENTRY_DATABASE.cache_entries
9394
],
94-
"trajectory_length": cache_entry_database.get_top_n_by_sort_key(
95-
TrajectorySortKey, n
96-
),
97-
"steps_since_led_to_something_new": cache_entry_database.get_top_n_by_sort_key(
98-
StepsSinceLedToSomethingNewSortKey, n
99-
),
100-
"steps_since_led_to_something_new_reset_count": cache_entry_database.get_top_n_by_sort_key(
101-
StepsSinceLedToSomethingNewResetCountSortKey, n
102-
),
103-
"sample_count": cache_entry_database.get_top_n_by_sort_key(
104-
SampleCountSortKey, n
105-
),
106-
"visit_count": cache_entry_database.get_top_n_by_sort_key(VisitCountSortKey, n),
95+
# "trajectory_length": _CACHE_ENTRY_DATABASE.get_top_n_by_sort_key(
96+
# TrajectorySortKey, n
97+
# ),
98+
# "steps_since_led_to_something_new": _CACHE_ENTRY_DATABASE.get_top_n_by_sort_key(
99+
# StepsSinceLedToSomethingNewSortKey, n
100+
# ),
101+
# "steps_since_led_to_something_new_reset_count": _CACHE_ENTRY_DATABASE.get_top_n_by_sort_key(
102+
# StepsSinceLedToSomethingNewResetCountSortKey, n
103+
# ),
104+
# "sample_count": _CACHE_ENTRY_DATABASE.get_top_n_by_sort_key(
105+
# SampleCountSortKey, n
106+
# ),
107+
# "visit_count": _CACHE_ENTRY_DATABASE.get_top_n_by_sort_key(
108+
# VisitCountSortKey, n
109+
# ),
107110
}
108111

109112

@@ -447,6 +450,6 @@ def go_explore():
447450
target=_OBJECT_LOG_CACHE.convert_logs_to_saved_databases,
448451
args=(_get_experiment_names(),),
449452
)
450-
convert_logs_background_thread.start()
453+
# convert_logs_background_thread.start()
451454

452455
app.run(host="0.0.0.0", port=6006)

0 commit comments

Comments
 (0)