Skip to content

Commit 3cb0ab9

Browse files
[FIX] misc fixes for triton-puzzles (#252)
1 parent c09680f commit 3cb0ab9

31 files changed

+4609
-8
lines changed

.gitignore

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ coverage.xml
5454
.hypothesis/
5555
.pytest_cache/
5656
cover/
57-
triton_viz/static/
5857

5958
# Translations
6059
*.mo
@@ -170,4 +169,3 @@ uv.lock
170169
triton_viz/version.py
171170
.subagents/
172171
subagent*.txt
173-
triton_viz/templates/index.html

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ If you want to run tests, run `uv sync --extra test` instead of `uv sync`. Other
6767

6868
### Frontend Build
6969

70-
If you want to run the visualizer, build the TS sources:
70+
The PyPI package ships with prebuilt frontend assets in `triton_viz/static`, so
71+
you do not need npm to run the visualizer. If you want to modify the frontend,
72+
rebuild the TS sources:
7173

7274
```sh
7375
npm install

frontend/assets/visualizer.css

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,21 @@ body {
708708
color: var(--text-secondary);
709709
}
710710

711+
/* webgl fallback banner */
712+
.webgl-warning {
713+
position: absolute;
714+
inset: 24px;
715+
display: flex;
716+
align-items: center;
717+
justify-content: center;
718+
text-align: center;
719+
font-size: 14px;
720+
color: #f8fafc;
721+
border-color: var(--warning);
722+
background: rgba(15, 23, 42, 0.88);
723+
box-shadow: var(--shadow-soft);
724+
}
725+
711726
.is-draggable {
712727
cursor: grab;
713728
user-select: none;

frontend/components/tensor_view.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import {
1717
GAP,
1818
COLOR_HOVER,
1919
updateTensorHighlights,
20+
canUseWebgl,
21+
renderWebglWarning,
2022
} from '../utils/three_utils.js';
2123
import { createHistogramOverlay } from './histogram.js';
2224
import { enableDrag } from '../utils/ui_helpers.js';
@@ -852,14 +854,26 @@ export function createTensorVisualization(
852854
const stage = document.createElement('div');
853855
stage.className = 'viz-stage';
854856
containerElement.appendChild(stage);
857+
if (!canUseWebgl()) {
858+
// show a visible message when WebGL is disabled and skip initialization.
859+
return renderWebglWarning(containerElement);
860+
}
855861
const sideMenu = createSideMenu(containerElement);
856862
const histogramUI = createHistogramOverlay(containerElement, {
857863
title: `${type} Value Distribution`,
858864
apiBase: API_BASE,
859865
sources: configs.map(c => ({ value: c.name.toUpperCase(), label: `${c.name} Tensor` })),
860866
buildRequestBody: (s, b) => ({ uuid: op.uuid, source: s, bins: b }),
861867
});
862-
const { scene, camera, renderer } = setupScene(stage, 0x000000);
868+
let scene: ThreeScene;
869+
let camera: ThreeCamera;
870+
let renderer: ThreeRenderer;
871+
try {
872+
({ scene, camera, renderer } = setupScene(stage, 0x000000));
873+
} catch (err) {
874+
// webgl can still fail even after a feature test.
875+
return renderWebglWarning(containerElement);
876+
}
863877
const disposer = createDisposer();
864878
const { cubeGeometry, edgesGeometry, lineMaterial } = setupGeometries();
865879
const tensors = new Map<string, TensorGroup>();

frontend/core/api.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@ export type RequestOptions = {
99
const getDefaultBase = (): string => {
1010
if (typeof globalThis === 'undefined') return '';
1111
const globalBase = (globalThis as typeof globalThis & { __TRITON_VIZ_API__?: string }).__TRITON_VIZ_API__;
12-
return globalBase || '';
12+
if (globalBase) return globalBase;
13+
if (typeof window === 'undefined' || !window.location) return '';
14+
// derive a stable base from the current page location for proxy paths
15+
const baseUrl = new URL('.', window.location.href);
16+
return baseUrl.href.replace(/\/$/, '');
1317
};
1418

1519
const buildUrl = (path?: string | null, baseOverride?: string | null): string => {

frontend/utils/three_utils.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,29 @@ export const COLOR_EDGE = new THREE.Color(0.5, 0.5, 0.5);
2222

2323
const COLOR_SLICE = new THREE.Color(0.0, 0.7, 1.0);
2424

25+
// quick feature test for WebGL availability in the current browser.
26+
export function canUseWebgl(): boolean {
27+
if (typeof document === 'undefined') return false;
28+
const canvas = document.createElement('canvas');
29+
const gl = canvas.getContext('webgl2') || canvas.getContext('webgl') || canvas.getContext('experimental-webgl');
30+
return !!gl;
31+
}
32+
33+
// render a visible warning when WebGL is missing or disabled.
34+
export function renderWebglWarning(container: HTMLElement): () => void {
35+
const existing = container.querySelector('.webgl-warning');
36+
if (existing) {
37+
return () => existing.remove();
38+
}
39+
const warning = document.createElement('div');
40+
warning.className = 'webgl-warning info-card';
41+
warning.setAttribute('role', 'alert');
42+
warning.setAttribute('aria-live', 'assertive');
43+
warning.textContent = 'WebGL is required for the 3D visualizer. Enable WebGL in your browser settings and reload this page.';
44+
container.appendChild(warning);
45+
return () => warning.remove();
46+
}
47+
2548
export function setupScene(container: HTMLElement, backgroundColor = 0x000000): {
2649
scene: ThreeScene;
2750
camera: ThreeCamera;

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ dependencies = [
3434
]
3535

3636
[tool.setuptools.packages.find]
37-
include=["triton_viz"]
37+
include=["triton_viz*"]
3838
exclude=["tasks"]
3939

4040
[tool.setuptools]
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import { postJson } from "../core/api.js";
2+
import { createDisposer } from "../utils/dispose.js";
3+
export function createHistogramOverlay(containerElement, options) {
4+
const { title = "Value Distribution", sources = [], apiBase = "", buildRequestBody, defaultBins = 64, } = options;
5+
if (!buildRequestBody) {
6+
throw new Error("buildRequestBody is required for histogram overlay");
7+
}
8+
const button = document.createElement("button");
9+
button.textContent = "Value Histogram";
10+
const overlay = document.createElement("div");
11+
Object.assign(overlay.style, {
12+
position: "fixed",
13+
top: "50px",
14+
right: "50px",
15+
width: "480px",
16+
padding: "12px",
17+
background: "rgba(0, 0, 0, 0.85)",
18+
color: "#fff",
19+
borderRadius: "8px",
20+
border: "1px solid #555",
21+
zIndex: 3000,
22+
display: "none",
23+
});
24+
const header = document.createElement("div");
25+
header.textContent = title;
26+
header.style.fontSize = "16px";
27+
header.style.marginBottom = "8px";
28+
header.style.fontWeight = "bold";
29+
overlay.appendChild(header);
30+
const controls = document.createElement("div");
31+
controls.style.display = "flex";
32+
controls.style.gap = "12px";
33+
controls.style.flexWrap = "wrap";
34+
controls.style.alignItems = "flex-end";
35+
const sourceGroup = document.createElement("div");
36+
sourceGroup.style.display = "flex";
37+
sourceGroup.style.flexDirection = "column";
38+
sourceGroup.style.gap = "4px";
39+
const sourceLabel = document.createElement("label");
40+
sourceLabel.textContent = "Activation";
41+
sourceLabel.style.fontSize = "12px";
42+
sourceLabel.style.opacity = "0.8";
43+
sourceGroup.appendChild(sourceLabel);
44+
const select = document.createElement("select");
45+
select.id = "histogram-source";
46+
sources.forEach((src) => {
47+
const opt = document.createElement("option");
48+
opt.value = src.value;
49+
opt.textContent = src.label;
50+
select.appendChild(opt);
51+
});
52+
if (!select.value && select.options.length) {
53+
select.selectedIndex = 0;
54+
}
55+
sourceLabel.htmlFor = select.id;
56+
sourceGroup.appendChild(select);
57+
controls.appendChild(sourceGroup);
58+
const binsGroup = document.createElement("div");
59+
binsGroup.style.display = "flex";
60+
binsGroup.style.flexDirection = "column";
61+
binsGroup.style.gap = "4px";
62+
const binsLabel = document.createElement("label");
63+
binsLabel.textContent = "Bins";
64+
binsLabel.style.fontSize = "12px";
65+
binsLabel.style.opacity = "0.8";
66+
binsGroup.appendChild(binsLabel);
67+
const binInput = document.createElement("input");
68+
binInput.type = "number";
69+
binInput.value = String(defaultBins);
70+
binInput.min = "4";
71+
binInput.max = "512";
72+
binInput.step = "2";
73+
binInput.style.width = "80px";
74+
binInput.title = "Number of bins";
75+
binInput.id = "histogram-bins";
76+
binsLabel.htmlFor = binInput.id;
77+
binsGroup.appendChild(binInput);
78+
controls.appendChild(binsGroup);
79+
overlay.appendChild(controls);
80+
const info = document.createElement("div");
81+
info.style.margin = "6px 0";
82+
info.style.fontSize = "12px";
83+
overlay.appendChild(info);
84+
const canvas = document.createElement("canvas");
85+
canvas.width = 440;
86+
canvas.height = 240;
87+
canvas.style.background = "#111";
88+
canvas.style.border = "1px solid #444";
89+
overlay.appendChild(canvas);
90+
const status = document.createElement("div");
91+
status.style.fontSize = "12px";
92+
status.style.marginTop = "4px";
93+
overlay.appendChild(status);
94+
function show() {
95+
overlay.style.display = "block";
96+
updateHistogram();
97+
}
98+
function hide() {
99+
overlay.style.display = "none";
100+
}
101+
const disposer = createDisposer();
102+
disposer.listen(button, "click", show);
103+
disposer.listen(select, "change", () => {
104+
updateHistogram();
105+
});
106+
disposer.listen(binInput, "input", () => {
107+
updateHistogram();
108+
});
109+
async function updateHistogram() {
110+
status.textContent = "Loading histogram...";
111+
info.textContent = "";
112+
const bins = parseInt(binInput.value, 10) || defaultBins;
113+
if (!select.value && select.options.length) {
114+
select.selectedIndex = 0;
115+
}
116+
try {
117+
const body = buildRequestBody(select.value, bins);
118+
body.bins = bins;
119+
body.max_samples = body.max_samples || 200000;
120+
const data = await postJson("/api/histogram", body, { base: apiBase });
121+
drawHistogram(canvas, data.counts, data.edges);
122+
info.textContent = `Min: ${data.min.toFixed(6)} | Max: ${data.max.toFixed(6)} | Total values: ${data.n} | Sampled: ${data.sampled}`;
123+
status.textContent = "";
124+
}
125+
catch (err) {
126+
const message = err instanceof Error ? err.message : String(err);
127+
status.textContent = `Histogram error: ${message}`;
128+
const ctx = canvas.getContext("2d");
129+
if (ctx)
130+
ctx.clearRect(0, 0, canvas.width, canvas.height);
131+
}
132+
}
133+
function drawHistogram(canvasEl, counts, edges) {
134+
const ctx = canvasEl.getContext("2d");
135+
if (!ctx)
136+
return;
137+
ctx.clearRect(0, 0, canvasEl.width, canvasEl.height);
138+
if (!counts || !counts.length || !edges || edges.length < 2) {
139+
ctx.fillStyle = "#888";
140+
ctx.fillText("No data", 20, 30);
141+
return;
142+
}
143+
const width = canvasEl.width - 40;
144+
const height = canvasEl.height - 40;
145+
const originX = 30;
146+
const originY = canvasEl.height - 20;
147+
const maxCount = Math.max(...counts);
148+
const barWidth = width / counts.length;
149+
ctx.strokeStyle = "#555";
150+
ctx.beginPath();
151+
ctx.moveTo(originX, originY);
152+
ctx.lineTo(originX + width, originY);
153+
ctx.stroke();
154+
counts.forEach((count, idx) => {
155+
const barHeight = maxCount ? (count / maxCount) * height : 0;
156+
const x = originX + idx * barWidth;
157+
const y = originY - barHeight;
158+
ctx.fillStyle = "#ffa500";
159+
ctx.fillRect(x, y, Math.max(1, barWidth - 2), barHeight);
160+
});
161+
ctx.fillStyle = "#ccc";
162+
ctx.font = "10px monospace";
163+
const firstEdge = edges[0];
164+
const lastEdge = edges[edges.length - 1];
165+
if (firstEdge === undefined || lastEdge === undefined)
166+
return;
167+
ctx.fillText(`${firstEdge.toFixed(4)}`, originX, originY + 12);
168+
ctx.fillText(`${lastEdge.toFixed(4)}`, originX + width - 40, originY + 12);
169+
}
170+
(containerElement || document.body).appendChild(overlay);
171+
return {
172+
button,
173+
overlay,
174+
show,
175+
hide,
176+
destroy() {
177+
disposer.dispose();
178+
overlay.remove();
179+
},
180+
};
181+
}

0 commit comments

Comments
 (0)