Skip to content

Commit 3f7d9d5

Browse files
authored
Merge pull request #27 from aerubanov/convergence-diagnostic
add trace plot
2 parents ea5db10 + c218f15 commit 3f7d9d5

File tree

8 files changed

+430
-2
lines changed

8 files changed

+430
-2
lines changed

src/App.css

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
.App-main {
1818
background: var(--color-bg-primary);
19-
overflow: hidden;
19+
overflow-y: auto;
2020
position: relative;
2121
}
2222

src/App.jsx

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import './App.css';
22
import Controls from './components/Controls';
33
import Visualizer from './components/Visualizer';
4+
import TracePlots from './components/TracePlots';
45
import useSamplingController from './hooks/useSamplingController';
56

67
function App() {
@@ -36,6 +37,7 @@ function App() {
3637
setUseSecondChain,
3738
setInitialPosition2,
3839
setSeed2,
40+
burnIn,
3941
} = useSamplingController();
4042

4143
return (
@@ -109,6 +111,16 @@ function App() {
109111
acceptedSamples2={samples2}
110112
useSecondChain={useSecondChain}
111113
/>
114+
<div className="trace-plots-section">
115+
{contourData && (
116+
<TracePlots
117+
samples={samples}
118+
samples2={samples2}
119+
burnIn={burnIn}
120+
useSecondChain={useSecondChain}
121+
/>
122+
)}
123+
</div>
112124
</div>
113125
</div>
114126
);

src/components/TracePlots.css

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
.trace-plots-container {
2+
display: flex;
3+
flex-direction: column;
4+
gap: 20px;
5+
width: 100%;
6+
margin-top: 20px;
7+
padding: 10px;
8+
background-color: #ffffff;
9+
border-radius: 8px;
10+
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
11+
}
12+
13+
.trace-plot-wrapper {
14+
width: 100%;
15+
display: flex;
16+
flex-direction: column;
17+
align-items: center;
18+
}
19+
20+
.trace-title {
21+
margin: 0 0 10px 0;
22+
font-size: 14px;
23+
color: #1a1a1a;
24+
font-weight: 600;
25+
align-self: flex-start;
26+
padding-left: 10px;
27+
}

src/components/TracePlots.jsx

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import './TracePlots.css';
2+
import Plot from 'react-plotly.js';
3+
import PropTypes from 'prop-types';
4+
import { TRACE_PLOT, HMC_SAMPLER } from '../utils/plotConfig.json';
5+
import { createTracePlotTrace } from '../utils/plotFunctions';
6+
7+
function TracePlots({ samples, samples2, burnIn, useSecondChain }) {
8+
const commonLayout = {
9+
...TRACE_PLOT.layout,
10+
showlegend: true,
11+
legend: {
12+
orientation: 'h',
13+
y: -0.2, // Move legend below plot
14+
},
15+
};
16+
17+
const xConfig = {
18+
displayModeBar: false,
19+
responsive: true,
20+
};
21+
22+
// Generate traces for X
23+
const xTraces = [];
24+
if (samples && samples.length > 0) {
25+
xTraces.push(
26+
...createTracePlotTrace(
27+
samples,
28+
'x',
29+
burnIn,
30+
HMC_SAMPLER.styles.primaryColor,
31+
'Chain 1'
32+
)
33+
);
34+
}
35+
36+
if (useSecondChain && samples2 && samples2.length > 0) {
37+
xTraces.push(
38+
...createTracePlotTrace(
39+
samples2,
40+
'x',
41+
burnIn,
42+
HMC_SAMPLER.styles.secondaryColor,
43+
'Chain 2'
44+
)
45+
);
46+
}
47+
48+
// Generate traces for Y
49+
const yTraces = [];
50+
if (samples && samples.length > 0) {
51+
yTraces.push(
52+
...createTracePlotTrace(
53+
samples,
54+
'y',
55+
burnIn,
56+
HMC_SAMPLER.styles.primaryColor,
57+
'Chain 1'
58+
)
59+
);
60+
}
61+
62+
if (useSecondChain && samples2 && samples2.length > 0) {
63+
yTraces.push(
64+
...createTracePlotTrace(
65+
samples2,
66+
'y',
67+
burnIn,
68+
HMC_SAMPLER.styles.secondaryColor,
69+
'Chain 2'
70+
)
71+
);
72+
}
73+
74+
return (
75+
<div className="trace-plots-container">
76+
<div className="trace-plot-wrapper">
77+
<h4 className="trace-title">X Trace</h4>
78+
<Plot
79+
data={xTraces}
80+
layout={{ ...commonLayout, title: '' }} // Remove title from Plotly, use HTML headers
81+
config={xConfig}
82+
style={{ width: '100%', height: '300px' }}
83+
useResizeHandler={true}
84+
/>
85+
</div>
86+
<div className="trace-plot-wrapper">
87+
<h4 className="trace-title">Y Trace</h4>
88+
<Plot
89+
data={yTraces}
90+
layout={{ ...commonLayout, title: '' }}
91+
config={xConfig}
92+
style={{ width: '100%', height: '300px' }}
93+
useResizeHandler={true}
94+
/>
95+
</div>
96+
</div>
97+
);
98+
}
99+
100+
TracePlots.propTypes = {
101+
samples: PropTypes.arrayOf(
102+
PropTypes.shape({
103+
x: PropTypes.number.isRequired,
104+
y: PropTypes.number.isRequired,
105+
})
106+
),
107+
samples2: PropTypes.arrayOf(
108+
PropTypes.shape({
109+
x: PropTypes.number.isRequired,
110+
y: PropTypes.number.isRequired,
111+
})
112+
),
113+
burnIn: PropTypes.number,
114+
useSecondChain: PropTypes.bool,
115+
};
116+
117+
export default TracePlots;

src/hooks/useSamplingController.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ export default function useSamplingController() {
3535
const [rejectedCount2, setRejectedCount2] = useState(0);
3636
const [seed2, setSeed2State] = useState(null);
3737

38+
// Visualization params
39+
const [burnIn] = useState(10);
40+
3841
// Refs to hold instances/values that don't trigger re-renders or need to be accessed in loops
3942
const logpInstanceRef = useRef(null);
4043
const currentParticleRef = useRef(null); // { q, p }
@@ -336,5 +339,6 @@ export default function useSamplingController() {
336339
setUseSecondChain,
337340
setInitialPosition2,
338341
setSeed2,
342+
burnIn,
339343
};
340344
}

src/utils/plotConfig.json

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,34 @@
3636
}
3737
}
3838
},
39+
"TRACE_PLOT": {
40+
"styles": {
41+
"burnInOpacity": 0.3,
42+
"lineWidth": 2
43+
},
44+
"layout": {
45+
"font": {
46+
"color": "#1a1a1a",
47+
"family": "Inter, system-ui, sans-serif",
48+
"size": 10
49+
},
50+
"xaxis": {
51+
"title": "Iteration",
52+
"showgrid": false,
53+
"zeroline": false
54+
},
55+
"yaxis": {
56+
"showgrid": true,
57+
"zeroline": false
58+
},
59+
"margin": {
60+
"l": 50,
61+
"r": 20,
62+
"t": 30,
63+
"b": 40
64+
}
65+
}
66+
},
3967
"GENERAL": {
4068
"layout": {
4169
"autosize": true,

src/utils/plotFunctions.js

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { CONTOUR, HMC_SAMPLER } from './plotConfig.json';
1+
import { CONTOUR, HMC_SAMPLER, TRACE_PLOT } from './plotConfig.json';
22

33
/**
44
* Creates a Plotly contour trace configuration
@@ -142,3 +142,94 @@ export function createSamplesTrace(
142142
hovertemplate: 'Sample<br>x: %{x:.2f}<br>y: %{y:.2f}<extra></extra>',
143143
};
144144
}
145+
146+
/**
147+
* Converts a hex color to rgba string
148+
* @param {string} hex - Hex color string (e.g., "#ff0000")
149+
* @param {number} alpha - Alpha value (0-1)
150+
* @returns {string} Rgba color string
151+
*/
152+
function hexToRgba(hex, alpha) {
153+
const r = parseInt(hex.slice(1, 3), 16);
154+
const g = parseInt(hex.slice(3, 5), 16);
155+
const b = parseInt(hex.slice(5, 7), 16);
156+
return `rgba(${r}, ${g}, ${b}, ${alpha})`;
157+
}
158+
159+
/**
160+
* Creates Plotly traces for trace plots (iteration vs value)
161+
* @param {Array<{x: number, y: number}>} samples - Array of accepted sample points
162+
* @param {string} axis - 'x' or 'y' to plot
163+
* @param {number} burnIn - Number of samples to treat as burn-in
164+
* @param {string} [color] - Color for the valid samples
165+
* @param {string} [name] - Name for the valid samples trace
166+
* @returns {object[]} Array of Plotly trace objects (burn-in and valid)
167+
*/
168+
export function createTracePlotTrace(
169+
samples,
170+
axis,
171+
burnIn = 0,
172+
color = HMC_SAMPLER.styles.primaryColor,
173+
name = 'Trace'
174+
) {
175+
if (!samples || !Array.isArray(samples) || samples.length === 0) {
176+
return [];
177+
}
178+
179+
const traces = [];
180+
const validOpacity = 1.0;
181+
const burnInOpacity = TRACE_PLOT.styles.burnInOpacity;
182+
const lineWidth = TRACE_PLOT.styles.lineWidth;
183+
184+
// Split samples into burn-in and valid
185+
let burnInSamples = [];
186+
let validSamples = [];
187+
188+
if (burnIn > 0) {
189+
// If we have valid samples after burn-in, include the first one in burn-in set to connect lines
190+
const endIndex = samples.length > burnIn ? burnIn + 1 : burnIn;
191+
burnInSamples = samples.slice(0, endIndex);
192+
validSamples = samples.slice(burnIn);
193+
} else {
194+
validSamples = samples;
195+
}
196+
197+
// Helper to create a single trace part
198+
const createSubTrace = (data, startIndex, opacity, traceName, showLegend) => {
199+
const iterations = data.map((_, i) => i + startIndex);
200+
const values = data.map((p) => p[axis]);
201+
202+
// Use RGBA for color to ensure opacity works reliably on lines
203+
const traceColor = opacity < 1 ? hexToRgba(color, opacity) : color;
204+
205+
return {
206+
type: 'scatter',
207+
mode: 'lines',
208+
x: iterations,
209+
y: values,
210+
line: {
211+
color: traceColor,
212+
width: lineWidth,
213+
},
214+
// Remove top-level opacity to rely on rgba color
215+
// opacity: opacity,
216+
name: traceName,
217+
showlegend: showLegend,
218+
hovertemplate: `Iter: %{x}<br>${axis}: %{y:.2f}<extra></extra>`,
219+
};
220+
};
221+
222+
// Add burn-in trace
223+
if (burnInSamples.length > 0) {
224+
traces.push(
225+
createSubTrace(burnInSamples, 0, burnInOpacity, `${name} (Burn-in)`, true)
226+
);
227+
}
228+
229+
// Add valid samples trace
230+
if (validSamples.length > 0) {
231+
traces.push(createSubTrace(validSamples, burnIn, validOpacity, name, true));
232+
}
233+
234+
return traces;
235+
}

0 commit comments

Comments
 (0)