Skip to content

Commit 0b7ec9b

Browse files
authored
Merge pull request #28 from NGO-Algorithm-Audit/feature/violin-lot-poc
Feature/violin lot poc
2 parents 93e13a9 + 7653011 commit 0b7ec9b

File tree

2 files changed

+366
-1
lines changed

2 files changed

+366
-1
lines changed

src/components/DistributionReport.tsx

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import HeatMapChart from './graphs/HeatMap';
77
import { UnivariateCharts } from './UnivariateCharts';
88
import { Accordion } from './ui/accordion';
99
import { createHeatmapdata } from './createHeatmapdata';
10+
import ViolinChart from './graphs/ViolinChart';
1011

1112
interface DistributionReport {
1213
reportType: string;
@@ -30,6 +31,10 @@ export const DistributionReport = (
3031
const syntheticData = JSON.parse(distributionReportProps.synthetic);
3132
const dataTypes = JSON.parse(distributionReportProps.dataTypes);
3233
console.log('reports', distributionReportProps.reports);
34+
35+
const columnNames = Object.keys(realData[0]).filter(column => {
36+
return column != 'realOrSynthetic';
37+
});
3338
return (
3439
<div className="flex flex-col gap-6">
3540
{distributionReportProps.reports.map(
@@ -92,13 +97,37 @@ export const DistributionReport = (
9297
report.reportType ===
9398
'bivariateDistributionSyntheticData'
9499
) {
100+
const charts = columnNames.map(column => {
101+
const dataType = dataTypes[column];
102+
return columnNames.map(column2 => {
103+
if (column === column2) {
104+
return null;
105+
}
106+
const dataType2 = dataTypes[column2];
107+
if (
108+
dataType === 'float' &&
109+
dataType2 === 'category'
110+
) {
111+
return (
112+
<ViolinChart
113+
key={column + column2}
114+
categoricalColumn={column2}
115+
numericColumn={column}
116+
realData={realData}
117+
syntheticData={syntheticData}
118+
/>
119+
);
120+
}
121+
return null;
122+
});
123+
});
95124
return (
96125
<div key={indexReport} className="mb-4">
97126
<Accordion
98127
title={t(
99128
'syntheticData.bivariateDistributionSyntheticData'
100129
)}
101-
content={<p>PLACEHOLDER</p>}
130+
content={<div>{charts}</div>}
102131
/>
103132
</div>
104133
);
Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
import { useEffect, useRef, useState } from 'react';
2+
import * as d3 from 'd3';
3+
import { useTranslation } from 'react-i18next';
4+
5+
interface ViolinChartProps {
6+
categoricalColumn: string;
7+
numericColumn: string;
8+
realData: Array<{ [key: string]: any }>;
9+
syntheticData: Array<{ [key: string]: any }>;
10+
}
11+
12+
const margin = { top: 30, right: 50, bottom: 60, left: 80 };
13+
const height = 380 - margin.top - margin.bottom;
14+
15+
const ViolinChart = ({
16+
categoricalColumn,
17+
numericColumn,
18+
realData,
19+
syntheticData,
20+
}: ViolinChartProps) => {
21+
const svgRef = useRef<SVGSVGElement>(null);
22+
const containerRef = useRef<HTMLDivElement>(null);
23+
const [containerWidth, setContainerWidth] = useState(800);
24+
const { t } = useTranslation();
25+
26+
useEffect(() => {
27+
const plotHeight = height - margin.top - margin.bottom;
28+
29+
// Get unique categories
30+
const categories = Array.from(
31+
new Set([
32+
...realData.map(d => d[categoricalColumn]),
33+
...syntheticData.map(d => d[categoricalColumn]),
34+
])
35+
);
36+
37+
// Process data for violin plots
38+
const violinData = categories.map(category => {
39+
const realValues = realData
40+
.filter(d => d[categoricalColumn] === category)
41+
.map(d => +d[numericColumn]);
42+
const syntheticValues = syntheticData
43+
.filter(d => d[categoricalColumn] === category)
44+
.map(d => +d[numericColumn]);
45+
46+
return {
47+
category,
48+
real: realValues,
49+
synthetic: syntheticValues,
50+
};
51+
});
52+
53+
// Clear previous content
54+
d3.select(svgRef.current).selectAll('*').remove();
55+
56+
// Create SVG
57+
const svg = d3
58+
.select(svgRef.current)
59+
.attr('width', containerWidth)
60+
.attr('height', height + margin.top + margin.bottom)
61+
.append('g')
62+
.attr('transform', `translate(${margin.left},${margin.top})`);
63+
64+
// Reserve space for legend (120px width + 20px padding)
65+
const legendWidth = 140;
66+
const plotWidth =
67+
containerWidth - margin.left - margin.right - legendWidth;
68+
69+
// Create scales
70+
const xScale = d3
71+
.scaleBand()
72+
.domain(categories)
73+
.range([0, plotWidth])
74+
.padding(0.1);
75+
76+
// Limit the bandwidth to 200px max (100px per side)
77+
const bandwidth = Math.min(xScale.bandwidth() / 2, 50);
78+
79+
// Calculate min and max with 10% padding on top
80+
const minValue =
81+
d3.min([
82+
...realData.map(d => +d[numericColumn]),
83+
...syntheticData.map(d => +d[numericColumn]),
84+
]) || 0;
85+
const maxValue =
86+
d3.max([
87+
...realData.map(d => +d[numericColumn]),
88+
...syntheticData.map(d => +d[numericColumn]),
89+
]) || 0;
90+
const paddedMaxValue = maxValue + (maxValue - minValue) * 0.1;
91+
92+
const yScale = d3
93+
.scaleLinear()
94+
.domain([minValue, paddedMaxValue])
95+
.range([plotHeight, 0]);
96+
97+
// Create violin plot for each category
98+
violinData.forEach(({ category, real, synthetic }) => {
99+
const xPos = xScale(category);
100+
if (xPos === undefined) return;
101+
102+
// Function to create violin path
103+
const createViolin = (values: number[], side: 'left' | 'right') => {
104+
// Create kernel density estimation
105+
const kde = kernelDensityEstimator(
106+
kernelEpanechnikov(0.2),
107+
yScale.ticks(50)
108+
);
109+
const density: [number, number][] = kde(values);
110+
const maxDensity = d3.max(density, d => d[1]) || 0;
111+
112+
const widthScale = d3
113+
.scaleLinear()
114+
.domain([0, maxDensity])
115+
.range([0, bandwidth]);
116+
117+
const area = d3
118+
.area<[number, number]>()
119+
.x0(d => {
120+
const width = widthScale(d[1]);
121+
return side === 'left' ? -width : 0;
122+
})
123+
.x1(d => {
124+
const width = widthScale(d[1]);
125+
return side === 'left' ? 0 : width;
126+
})
127+
.y(d => yScale(d[0]))
128+
.curve(d3.curveBasis);
129+
130+
return area(density);
131+
};
132+
133+
// Kernel functions
134+
function kernelDensityEstimator(
135+
kernel: (v: number) => number,
136+
X: number[]
137+
): (V: number[]) => [number, number][] {
138+
return function (V: number[]) {
139+
return X.map(x => [x, d3.mean(V, v => kernel(x - v)) || 0]);
140+
};
141+
}
142+
143+
function kernelEpanechnikov(k: number) {
144+
return function (v: number) {
145+
return Math.abs((v /= k)) <= 1
146+
? (0.75 * (1 - v * v)) / k
147+
: 0;
148+
};
149+
}
150+
151+
// Calculate center position for the violin plot
152+
const centerPos = xPos + xScale.bandwidth() / 2;
153+
154+
// Function to calculate quartiles
155+
const calculateQuartiles = (values: number[]) => {
156+
const sorted = [...values].sort((a, b) => a - b);
157+
return {
158+
q1: d3.quantile(sorted, 0.25) || 0,
159+
q2: d3.quantile(sorted, 0.5) || 0,
160+
q3: d3.quantile(sorted, 0.75) || 0,
161+
};
162+
};
163+
164+
// Function to draw quartile lines
165+
const drawQuartileLines = (
166+
values: number[],
167+
side: 'left' | 'right',
168+
color: string
169+
) => {
170+
const quartiles = calculateQuartiles(values);
171+
const maxWidth = bandwidth;
172+
173+
// Draw lines for each quartile
174+
Object.values(quartiles).forEach(q => {
175+
svg.append('line')
176+
.attr('x1', side === 'left' ? -maxWidth : 0)
177+
.attr('x2', side === 'left' ? 0 : maxWidth)
178+
.attr('y1', yScale(q))
179+
.attr('y2', yScale(q))
180+
.attr('transform', `translate(${centerPos}, 0)`)
181+
.style('stroke', color)
182+
.style('stroke-width', 1)
183+
.style('stroke-dasharray', '3,3');
184+
});
185+
};
186+
187+
// Draw real data violin (left side)
188+
if (real.length > 0) {
189+
svg.append('path')
190+
.attr('d', createViolin(real, 'left'))
191+
.attr('transform', `translate(${centerPos}, 0)`)
192+
.style('fill', 'steelblue')
193+
.style('opacity', 0.5);
194+
195+
drawQuartileLines(real, 'left', 'steelblue');
196+
}
197+
198+
// Draw synthetic data violin (right side)
199+
if (synthetic.length > 0) {
200+
svg.append('path')
201+
.attr('d', createViolin(synthetic, 'right'))
202+
.attr('transform', `translate(${centerPos}, 0)`)
203+
.style('fill', 'orange')
204+
.style('opacity', 0.5);
205+
206+
drawQuartileLines(synthetic, 'right', 'orange');
207+
}
208+
});
209+
210+
// Add axes
211+
svg.append('g')
212+
.attr('transform', `translate(0,${plotHeight})`)
213+
.call(d3.axisBottom(xScale))
214+
.selectAll('text')
215+
.attr('transform', 'rotate(-45)')
216+
.style('text-anchor', 'end')
217+
.attr('dx', '-.8em')
218+
.attr('dy', '.15em');
219+
220+
svg.append('g').call(d3.axisLeft(yScale));
221+
222+
// Add title
223+
svg.append('text')
224+
.attr('x', plotWidth / 2)
225+
.attr('y', -10)
226+
.attr('text-anchor', 'middle')
227+
.style('font-size', '12px')
228+
.style('font-weight', 'bold')
229+
.text(
230+
`${t('distribution.distributionOf')} ${numericColumn} ${t(
231+
'distribution.by'
232+
)} ${categoricalColumn}`
233+
);
234+
235+
// Add y-axis label
236+
svg.append('text')
237+
.attr('transform', 'rotate(-90)')
238+
.attr('y', -50)
239+
.attr('x', -plotHeight / 2)
240+
.attr('text-anchor', 'middle')
241+
.attr('font-size', '12px')
242+
.text(numericColumn);
243+
244+
// Add legend at fixed position relative to plot area
245+
const legend = svg
246+
.append('g')
247+
.attr('class', 'legend')
248+
.attr('transform', `translate(${plotWidth + 20}, 30)`);
249+
250+
// No need to adjust SVG width since we reserved space for legend
251+
252+
// Add background rectangle for legend
253+
legend
254+
.append('rect')
255+
.attr('x', -10)
256+
.attr('y', -10)
257+
.attr('width', 110)
258+
.attr('height', 55)
259+
.attr('rx', 5)
260+
.style('fill', 'white')
261+
.style('opacity', 0.7)
262+
.style('stroke', '#e2e8f0')
263+
.style('stroke-width', 1);
264+
265+
// Add legend items
266+
legend
267+
.append('rect')
268+
.attr('x', 0)
269+
.attr('y', 0)
270+
.attr('width', 15)
271+
.attr('height', 15)
272+
.style('fill', 'steelblue')
273+
.style('opacity', 0.5);
274+
275+
legend
276+
.append('text')
277+
.attr('x', 20)
278+
.attr('y', 12)
279+
.style('font-size', '12px')
280+
.text(t('distribution.realData'));
281+
282+
legend
283+
.append('rect')
284+
.attr('x', 0)
285+
.attr('y', 20)
286+
.attr('width', 15)
287+
.attr('height', 15)
288+
.style('fill', 'orange')
289+
.style('opacity', 0.5);
290+
291+
legend
292+
.append('text')
293+
.attr('x', 20)
294+
.attr('y', 32)
295+
.style('font-size', '12px')
296+
.text(t('distribution.syntheticData'));
297+
}, [
298+
containerWidth,
299+
categoricalColumn,
300+
numericColumn,
301+
realData,
302+
syntheticData,
303+
]);
304+
305+
useEffect(() => {
306+
const resizeObserver = new ResizeObserver(entries => {
307+
if (!entries || entries.length === 0) return;
308+
const { width } = entries[0].contentRect;
309+
if (width > 0) {
310+
setContainerWidth(width);
311+
}
312+
});
313+
314+
if (containerRef.current) {
315+
resizeObserver.observe(containerRef.current);
316+
}
317+
318+
return () => {
319+
if (containerRef.current) {
320+
resizeObserver.unobserve(containerRef.current);
321+
}
322+
};
323+
}, []);
324+
325+
return (
326+
<div
327+
ref={containerRef}
328+
style={{ width: '100%', display: 'flex', overflowX: 'auto' }}
329+
className={`chart-container min-h-[${height}px] flex-col`}
330+
>
331+
<svg ref={svgRef}></svg>
332+
</div>
333+
);
334+
};
335+
336+
export default ViolinChart;

0 commit comments

Comments
 (0)