Skip to content

Commit 141f5fa

Browse files
committed
Some buggy visualization code
1 parent 7f4d82c commit 141f5fa

File tree

3 files changed

+191
-96
lines changed

3 files changed

+191
-96
lines changed

bridger/logging_utils/cache_entry_database.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from bridger.go_explore_phase_1 import CacheEntry
22
import time
3+
import numpy as np
34

45

56
class CacheEntryDatabase:
@@ -31,18 +32,53 @@ def sort_by_key(self, sort_key: "SortKey"):
3132
print(f"Error during sorting: {e}")
3233
raise
3334

34-
def get_top_n_by_sort_key(self, sort_key: "SortKey", n: int):
35+
def get_top_n_by_sort_key(
36+
self, sort_key: "SortKey", n: int, ascending: bool = False
37+
):
3538
"""
36-
Returns the top n cache entries.
39+
Returns the top n cache entries along with their metric values.
40+
41+
Args:
42+
sort_key: The key to sort by
43+
n: Number of entries to return
44+
ascending: If True, sort in ascending order, otherwise descending (default)
45+
46+
Returns:
47+
A dictionary containing:
48+
- states: List of state representations
49+
- values: List of corresponding metric values
3750
"""
3851
print(f"Getting top {n} entries by {sort_key.key}")
3952
start_time = time.time()
4053
try:
4154
self.sort_by_key(sort_key)
42-
result = [
43-
cache_entry.state_representative.tolist()
44-
for cache_entry in self.cache_entries[:n].copy()
45-
]
55+
if ascending:
56+
entries = self.cache_entries[:n].copy()
57+
else:
58+
entries = self.cache_entries[-n:].copy()
59+
entries.reverse()
60+
61+
# Convert values to Python native types, handling both numbers and tuples
62+
def convert_value(value):
63+
if isinstance(value, (np.int64, np.int32)):
64+
return int(value)
65+
elif isinstance(value, tuple):
66+
return [
67+
int(x) if isinstance(x, (np.int64, np.int32)) else x
68+
for x in value
69+
]
70+
return value
71+
72+
result = {
73+
"states": [
74+
cache_entry.state_representative.tolist() for cache_entry in entries
75+
],
76+
"values": [
77+
convert_value(getattr(cache_entry, sort_key.key))
78+
for cache_entry in entries
79+
],
80+
}
81+
4682
end_time = time.time()
4783
print(f"Retrieved top {n} entries in {end_time - start_time:.2f} seconds")
4884
return result

tools/web/sibyl.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import threading
55
import time
66
import json
7+
import numpy as np
78

89
from typing import Any, Optional
910

@@ -90,25 +91,31 @@ def go_explore_plot_data():
9091
cache_entries = [cache_entry for cache in caches for cache_entry in cache]
9192
_CACHE_ENTRY_DATABASE = CacheEntryDatabase(cache_entries)
9293

94+
# Convert all states to lists and ensure numeric values are Python native types
95+
states = [
96+
[
97+
int(x) if isinstance(x, (np.int64, np.int32)) else x
98+
for x in cache_entry.state_representative.tolist()
99+
]
100+
for cache_entry in _CACHE_ENTRY_DATABASE.cache_entries
101+
]
102+
93103
return {
94-
"states": [
95-
cache_entry.state_representative.tolist()
96-
for cache_entry in _CACHE_ENTRY_DATABASE.cache_entries
97-
],
104+
"states": states,
98105
"trajectory_length": _CACHE_ENTRY_DATABASE.get_top_n_by_sort_key(
99-
TrajectorySortKey(), n
106+
TrajectorySortKey(), n, ascending=False
100107
),
101108
"steps_since_led_to_something_new": _CACHE_ENTRY_DATABASE.get_top_n_by_sort_key(
102-
StepsSinceLedToSomethingNewSortKey(), n
109+
StepsSinceLedToSomethingNewSortKey(), n, ascending=False
103110
),
104111
"steps_since_led_to_something_new_reset_count": _CACHE_ENTRY_DATABASE.get_top_n_by_sort_key(
105-
StepsSinceLedToSomethingNewResetCountSortKey(), n
112+
StepsSinceLedToSomethingNewResetCountSortKey(), n, ascending=False
106113
),
107114
"sampled_count": _CACHE_ENTRY_DATABASE.get_top_n_by_sort_key(
108-
SampleCountSortKey(), n
115+
SampleCountSortKey(), n, ascending=False
109116
),
110117
"visit_count": _CACHE_ENTRY_DATABASE.get_top_n_by_sort_key(
111-
VisitCountSortKey(), n
118+
VisitCountSortKey(), n, ascending=False
112119
),
113120
}
114121

@@ -422,8 +429,8 @@ def action_inversion():
422429
)
423430

424431

425-
@app.route("/go_explore")
426-
def go_explore():
432+
@app.route("/go_explore_visualization")
433+
def go_explore_visualization():
427434
experiment_names = _get_experiment_names()
428435
selected_experiment_name = _get_string_or_default(
429436
name=_EXPERIMENT_NAME, default=experiment_names[0]

tools/web/static/js/go_explore_plots.js

Lines changed: 131 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,36 @@ const COLORS = [
4646
];
4747

4848
let data = null;
49+
let currentMetric = "trajectory_length"; // Default metric
50+
51+
// Available metrics with their display names
52+
const METRICS = [
53+
{ key: "trajectory_length", label: "Trajectory Length" },
54+
{ key: "steps_since_led_to_something_new", label: "Steps Since New Cell" },
55+
{ key: "steps_since_led_to_something_new_reset_count", label: "Reset Count" },
56+
{ key: "sampled_count", label: "Sample Count" },
57+
{ key: "visit_count", label: "Visit Count" },
58+
];
4959

5060
// Function to update plots with new data
5161
function updatePlots() {
52-
const n = document.getElementById("n-states").value || 10;
62+
const nInput = document.getElementById("n-states");
63+
const metricSelector = document.getElementById("metric-selector");
64+
65+
if (!nInput || !metricSelector) {
66+
console.error("Required DOM elements not found");
67+
return;
68+
}
69+
70+
const n = nInput.value || 10;
71+
currentMetric = metricSelector.value;
5372

5473
// Show loading indicator
5574
document.getElementById("loading").classList.remove("hidden");
75+
document.getElementById("load-error").classList.add("hidden");
5676

5777
// Fetch data from endpoint
58-
fetch(
59-
`${_ROOT_URL}n_fewest_steps_since_led_to_something_new_go_explore?n=${n}`
60-
)
78+
fetch(`${_ROOT_URL}go_explore?n=${n}`)
6179
.then((response) => response.json())
6280
.then((responseData) => {
6381
data = responseData;
@@ -87,71 +105,45 @@ function renderPlots() {
87105
// Clear existing plots
88106
container.innerHTML = "";
89107

90-
// Create plots for each metric
91-
const metrics = [
92-
{ key: "trajectory_length", label: "Trajectory Length" },
93-
{ key: "steps_since_led_to_something_new", label: "Steps Since New Cell" },
94-
{
95-
key: "steps_since_led_to_something_new_reset_count",
96-
label: "Reset Count",
97-
},
98-
{ key: "sample_count", label: "Sample Count" },
99-
{ key: "visit_count", label: "Visit Count" },
100-
];
101-
102-
metrics.forEach((metric, index) => {
103-
const plotDiv = document.createElement("div");
104-
plotDiv.className = "plot-container";
105-
container.appendChild(plotDiv);
106-
108+
// Get the current metric data
109+
const metricData = data[currentMetric];
110+
if (!metricData || !metricData.states || !metricData.values) return;
111+
112+
// Find the metric label
113+
const metricLabel =
114+
METRICS.find((m) => m.key === currentMetric)?.label || currentMetric;
115+
116+
// Create a title for the current metric
117+
const metricTitle = document.createElement("h2");
118+
metricTitle.style.color = "#FFFFFF";
119+
metricTitle.style.textAlign = "center";
120+
metricTitle.style.marginBottom = "20px";
121+
metricTitle.textContent = metricLabel;
122+
container.appendChild(metricTitle);
123+
124+
// Create state grid container
125+
const gridContainer = document.createElement("div");
126+
gridContainer.className = "state-grid";
127+
container.appendChild(gridContainer);
128+
129+
// Create state visualizations
130+
metricData.states.forEach((state, stateIndex) => {
131+
const stateContainer = document.createElement("div");
132+
stateContainer.className = "state-container";
133+
134+
// Create title with metric value
135+
const title = document.createElement("h3");
136+
title.textContent = `Value: ${metricData.values[stateIndex]}`;
137+
stateContainer.appendChild(title);
138+
139+
// Create canvas for state visualization
107140
const canvas = document.createElement("canvas");
108-
canvas.id = `plot-${metric.key}`;
109-
plotDiv.appendChild(canvas);
110-
111-
const chartOptions = structuredClone(CHART_OPTIONS_TEMPLATE);
112-
chartOptions.plugins.title.text = metric.label;
113-
114-
new Chart(canvas, {
115-
type: "bar",
116-
data: {
117-
labels: Array.from(
118-
{ length: data[metric.key].length },
119-
(_, i) => `State ${i + 1}`
120-
),
121-
datasets: [
122-
{
123-
label: metric.label,
124-
data: data[metric.key],
125-
backgroundColor: COLORS[index % COLORS.length],
126-
borderColor: COLORS[index % COLORS.length],
127-
},
128-
],
129-
},
130-
options: chartOptions,
131-
});
132-
});
133-
134-
// Create state visualization grid
135-
const stateGrid = document.createElement("div");
136-
stateGrid.id = "state-grid";
137-
container.appendChild(stateGrid);
138-
139-
// Add state visualizations
140-
data.states.forEach((state, index) => {
141-
const stateDiv = document.createElement("div");
142-
stateDiv.className = "state-container";
143-
stateGrid.appendChild(stateDiv);
144-
145-
const stateTitle = document.createElement("h3");
146-
stateTitle.textContent = `State ${index + 1}`;
147-
stateDiv.appendChild(stateTitle);
141+
stateContainer.appendChild(canvas);
148142

149-
const stateCanvas = document.createElement("canvas");
150-
stateCanvas.id = `state-${index}`;
151-
stateDiv.appendChild(stateCanvas);
143+
// Render state grid
144+
renderStateGrid(state, canvas);
152145

153-
// Render 2D state array
154-
renderStateGrid(state, stateCanvas);
146+
gridContainer.appendChild(stateContainer);
155147
});
156148
}
157149

@@ -171,31 +163,91 @@ function renderStateGrid(state, canvas) {
171163
// Draw state grid
172164
state.forEach((row, y) => {
173165
row.forEach((value, x) => {
174-
ctx.fillStyle = value ? "#FFFFFF" : "#000000";
166+
// Use different colors for different values
167+
let color;
168+
if (value === 0) {
169+
color = "#000000"; // Black for empty
170+
} else if (value === 1) {
171+
color = "#FFFFFF"; // White for walls/obstacles
172+
} else if (value === 2) {
173+
color = "#FF0000"; // Red for player/agent
174+
} else if (value === 3) {
175+
color = "#00FF00"; // Green for goals
176+
} else {
177+
color = "#888888"; // Gray for other values
178+
}
179+
180+
ctx.fillStyle = color;
175181
ctx.fillRect(
176182
x * (cellSize + padding),
177183
y * (cellSize + padding),
178184
cellSize,
179185
cellSize
180186
);
187+
188+
// Add a subtle border around each cell
189+
ctx.strokeStyle = "#333";
190+
ctx.strokeRect(
191+
x * (cellSize + padding),
192+
y * (cellSize + padding),
193+
cellSize,
194+
cellSize
195+
);
181196
});
182197
});
183198
}
184199

185200
// Initialize when document is ready
186201
document.addEventListener("DOMContentLoaded", () => {
187-
// Add controls if they don't exist
188-
const controls = document.createElement("div");
189-
controls.className = "controls";
190-
controls.innerHTML = `
191-
<div class="control">
192-
<label for="n-states">Number of States:</label>
193-
<input type="number" id="n-states" value="10" min="1" max="50">
194-
</div>
195-
<button onclick="updatePlots()">Update Plots</button>
202+
// Create controls if they don't exist
203+
let controls = document.querySelector(".controls");
204+
if (!controls) {
205+
controls = document.createElement("div");
206+
controls.className = "controls";
207+
document.body.insertBefore(controls, document.body.firstChild);
208+
}
209+
210+
// Create metric selector
211+
const metricSelector = document.createElement("select");
212+
metricSelector.id = "metric-selector";
213+
metricSelector.className = "control";
214+
215+
// Add metric options
216+
METRICS.forEach((metric) => {
217+
const option = document.createElement("option");
218+
option.value = metric.key;
219+
option.textContent = metric.label;
220+
metricSelector.appendChild(option);
221+
});
222+
223+
// Add metric selector to controls
224+
const metricControl = document.createElement("div");
225+
metricControl.className = "control";
226+
metricControl.innerHTML = `
227+
<label for="metric-selector">Metric:</label>
228+
${metricSelector.outerHTML}
229+
`;
230+
controls.insertBefore(metricControl, controls.firstChild);
231+
232+
// Add number of states input if it doesn't exist
233+
if (!document.getElementById("n-states")) {
234+
const nStatesControl = document.createElement("div");
235+
nStatesControl.className = "control";
236+
nStatesControl.innerHTML = `
237+
<label for="n-states">Number of States:</label>
238+
<input type="number" id="n-states" value="10" min="1" max="50">
196239
`;
197-
document.body.insertBefore(controls, document.body.firstChild);
240+
controls.appendChild(nStatesControl);
241+
}
242+
243+
// Add update button if it doesn't exist
244+
if (!document.querySelector(".controls button")) {
245+
const updateButton = document.createElement("button");
246+
updateButton.textContent = "Update Plots";
247+
updateButton.onclick = updatePlots;
248+
controls.appendChild(updateButton);
249+
}
198250

199-
// Initial plot update
200-
updatePlots();
251+
// Wait a short moment for DOM to be fully updated before initial plot update
252+
setTimeout(updatePlots, 100);
201253
});

0 commit comments

Comments
 (0)