Skip to content

Commit defd9ab

Browse files
committed
ens
1 parent 8c7a750 commit defd9ab

File tree

3 files changed

+408
-14
lines changed

3 files changed

+408
-14
lines changed

nbs/ensemble/widget.css

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
.tree-widget {
2+
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
3+
background: #fafafa;
4+
border: 1px solid #e0e0e0;
5+
border-radius: 6px;
6+
}
7+
8+
.tree-info-panel {
9+
padding: 6px 10px;
10+
font-size: 13px;
11+
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
12+
color: #555;
13+
border-top: 1px solid #e0e0e0;
14+
}
15+
16+
.tree-edge {
17+
stroke: #bbb;
18+
stroke-width: 1;
19+
transition: stroke 0.15s, stroke-width 0.15s;
20+
}
21+
.tree-edge.highlighted {
22+
stroke: #2563eb;
23+
stroke-width: 2;
24+
}
25+
26+
.tree-node rect {
27+
fill: #fff;
28+
stroke: #999;
29+
stroke-width: 1;
30+
transition: fill 0.15s, stroke 0.15s;
31+
}
32+
.tree-leaf rect {
33+
fill: #f0f4ff;
34+
}
35+
36+
.tree-node.on-path rect {
37+
fill: #dbeafe;
38+
stroke: #2563eb;
39+
stroke-width: 1.5;
40+
}
41+
.tree-node.selected rect {
42+
fill: #2563eb;
43+
stroke: #1d4ed8;
44+
stroke-width: 2;
45+
}
46+
.tree-node:hover rect {
47+
stroke: #2563eb;
48+
stroke-width: 1.5;
49+
}
50+
51+
/* Dark mode — triggered by .dark class on wrapper */
52+
.dark .tree-widget {
53+
background: #1a1a2e;
54+
border-color: #333;
55+
}
56+
.dark .tree-info-panel {
57+
color: #aaa;
58+
border-top-color: #333;
59+
}
60+
.dark .tree-edge {
61+
stroke: #555;
62+
}
63+
.dark .tree-edge.highlighted {
64+
stroke: #60a5fa;
65+
}
66+
.dark .tree-node rect {
67+
fill: #2a2a3e;
68+
stroke: #555;
69+
}
70+
.dark .tree-leaf rect {
71+
fill: #1e293b;
72+
}
73+
.dark .tree-node.on-path rect {
74+
fill: #1e3a5f;
75+
stroke: #60a5fa;
76+
}
77+
.dark .tree-node.selected rect {
78+
fill: #2563eb;
79+
stroke: #3b82f6;
80+
}

nbs/ensemble/widget.js

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
function render({ model, el }) {
2+
function draw() {
3+
el.innerHTML = "";
4+
const data = model.get("tree_data");
5+
if (!data || !data.nodes || data.nodes.length === 0) return;
6+
7+
const DISPLAY_W = 700;
8+
const DISPLAY_H = 400;
9+
10+
// Detect dark mode from marimo/document context
11+
const isDark =
12+
document.documentElement.dataset.colorMode === "dark" ||
13+
document.documentElement.classList.contains("dark") ||
14+
document.body.classList.contains("dark");
15+
16+
const wrapper = document.createElement("div");
17+
wrapper.classList.add("tree-wrapper");
18+
if (isDark) wrapper.classList.add("dark");
19+
el.appendChild(wrapper);
20+
21+
const svg = document.createElementNS("http://www.w3.org/2000/svg", "svg");
22+
svg.setAttribute("width", DISPLAY_W);
23+
svg.setAttribute("height", DISPLAY_H);
24+
svg.setAttribute("viewBox", `0 0 ${data.width} ${data.height}`);
25+
svg.setAttribute("preserveAspectRatio", "xMidYMid meet");
26+
svg.classList.add("tree-widget");
27+
wrapper.appendChild(svg);
28+
29+
// Info panel below SVG
30+
const panel = document.createElement("div");
31+
panel.classList.add("tree-info-panel");
32+
panel.textContent = "Click a node to inspect it.";
33+
wrapper.appendChild(panel);
34+
35+
// Build lookups
36+
const nodeMap = {};
37+
data.nodes.forEach((n) => {
38+
nodeMap[n.id] = n;
39+
});
40+
41+
// Build parent map from edges for path computation in JS
42+
const parentOf = {};
43+
data.edges.forEach((edge) => {
44+
parentOf[edge.target] = edge.source;
45+
});
46+
47+
// Compute path from root to a node entirely in JS
48+
function pathToNode(nodeId) {
49+
if (nodeId < 0) return [];
50+
const path = [nodeId];
51+
let current = nodeId;
52+
while (parentOf[current] !== undefined) {
53+
current = parentOf[current];
54+
path.push(current);
55+
}
56+
path.reverse();
57+
return path;
58+
}
59+
60+
// Draw edges
61+
const edgeEls = {};
62+
data.edges.forEach((edge) => {
63+
const src = nodeMap[edge.source];
64+
const tgt = nodeMap[edge.target];
65+
const line = document.createElementNS("http://www.w3.org/2000/svg", "line");
66+
line.setAttribute("x1", src.x);
67+
line.setAttribute("y1", src.y);
68+
line.setAttribute("x2", tgt.x);
69+
line.setAttribute("y2", tgt.y);
70+
line.classList.add("tree-edge");
71+
svg.appendChild(line);
72+
edgeEls[`${edge.source}-${edge.target}`] = line;
73+
});
74+
75+
// Draw nodes as small circles
76+
const nodeEls = {};
77+
const nodeR = 4;
78+
data.nodes.forEach((node) => {
79+
const circle = document.createElementNS("http://www.w3.org/2000/svg", "circle");
80+
circle.setAttribute("cx", node.x);
81+
circle.setAttribute("cy", node.y);
82+
circle.setAttribute("r", nodeR);
83+
circle.classList.add("tree-node");
84+
if (node.is_leaf) circle.classList.add("tree-leaf");
85+
svg.appendChild(circle);
86+
nodeEls[node.id] = circle;
87+
});
88+
89+
// Use SVG's native coordinate transform (guaranteed correct with viewBox)
90+
function toSVG(e) {
91+
const pt = svg.createSVGPoint();
92+
pt.x = e.clientX;
93+
pt.y = e.clientY;
94+
const svgPt = pt.matrixTransform(svg.getScreenCTM().inverse());
95+
return { x: svgPt.x, y: svgPt.y };
96+
}
97+
98+
// Find nearest node to a point
99+
function nearestNode(px, py, maxDist) {
100+
let best = null;
101+
let bestDist = maxDist * maxDist;
102+
data.nodes.forEach((n) => {
103+
const dx = n.x - px;
104+
const dy = n.y - py;
105+
const d2 = dx * dx + dy * dy;
106+
if (d2 < bestDist) {
107+
bestDist = d2;
108+
best = n;
109+
}
110+
});
111+
return best;
112+
}
113+
114+
// Show node info in panel
115+
function showInfo(node, prefix) {
116+
if (!node) {
117+
panel.textContent = "Click a node to inspect it.";
118+
return;
119+
}
120+
const p = prefix ? prefix + " " : "";
121+
if (node.is_leaf) {
122+
panel.textContent = `${p}Leaf — samples: ${node.samples}`;
123+
} else {
124+
panel.textContent = `${p}Split: ${node.label} — samples: ${node.samples}`;
125+
}
126+
}
127+
128+
// Apply highlight for a given path + selected node (no Python needed)
129+
let currentPath = [];
130+
function applyHighlight(selectedId) {
131+
currentPath = pathToNode(selectedId);
132+
const pathSet = new Set(currentPath);
133+
134+
const pathEdgeKeys = new Set();
135+
for (let i = 0; i < currentPath.length - 1; i++) {
136+
pathEdgeKeys.add(`${currentPath[i]}-${currentPath[i + 1]}`);
137+
}
138+
139+
Object.entries(edgeEls).forEach(([key, line]) => {
140+
line.classList.toggle("highlighted", pathEdgeKeys.has(key));
141+
});
142+
143+
Object.entries(nodeEls).forEach(([id, circle]) => {
144+
const nid = parseInt(id);
145+
circle.classList.toggle("on-path", pathSet.has(nid));
146+
circle.classList.toggle("selected", nid === selectedId);
147+
});
148+
149+
if (selectedId >= 0 && nodeMap[selectedId]) {
150+
showInfo(nodeMap[selectedId], "Selected");
151+
} else {
152+
showInfo(null);
153+
}
154+
}
155+
156+
// Hover: show info for nearest node
157+
svg.addEventListener("mousemove", (e) => {
158+
const pt = toSVG(e);
159+
const node = nearestNode(pt.x, pt.y, 20);
160+
if (node) {
161+
svg.style.cursor = "pointer";
162+
const sel = model.get("selected_node");
163+
if (sel < 0 || node.id !== sel) {
164+
showInfo(node, "");
165+
}
166+
} else {
167+
svg.style.cursor = "";
168+
const sel = model.get("selected_node");
169+
if (sel >= 0 && nodeMap[sel]) {
170+
showInfo(nodeMap[sel], "Selected");
171+
} else {
172+
showInfo(null);
173+
}
174+
}
175+
});
176+
177+
// Click: select nearest node, highlight immediately in JS, then sync to Python
178+
svg.addEventListener("click", (e) => {
179+
const pt = toSVG(e);
180+
const node = nearestNode(pt.x, pt.y, 20);
181+
if (node) {
182+
// Highlight immediately — no Python roundtrip needed
183+
applyHighlight(node.id);
184+
model.set("selected_node", node.id);
185+
model.save_changes();
186+
} else {
187+
applyHighlight(-1);
188+
model.set("selected_node", -1);
189+
model.save_changes();
190+
}
191+
});
192+
193+
// Also handle changes from Python side (e.g. programmatic selection)
194+
model.on("change:selected_node", () => {
195+
applyHighlight(model.get("selected_node"));
196+
});
197+
198+
// Initial state
199+
applyHighlight(model.get("selected_node"));
200+
}
201+
202+
draw();
203+
model.on("change:tree_data", draw);
204+
}
205+
206+
export default { render };

0 commit comments

Comments
 (0)