Skip to content

Commit d7cd623

Browse files
authored
Merge pull request #12 from NGO-Algorithm-Audit/feature/distribution-chart
Feature/distribution chart
2 parents a6f1676 + 8f0acb5 commit d7cd623

File tree

3 files changed

+257
-9
lines changed

3 files changed

+257
-9
lines changed

src/assets/synthetic-data.tsx

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -618,16 +618,10 @@ def run():
618618
))
619619
setResult(json.dumps({'type': 'table', 'data': synthetic_data.head().to_json(orient="records")}))
620620
621-
np.random.seed(42)
622-
# heatmap = np.random.rand(100, 10)
621+
setResult(json.dumps({'type': 'heatmap', 'real': real_data.corr().to_json(orient="records"), 'synthetic': synthetic_data.corr().to_json(orient="records")}))
623622
624-
# Compute the (test) correlation matrix
625-
# correlation_matrix = np.corrcoef(heatmap, rowvar=False)
626-
# setResult(json.dumps({'type': 'heatmap', 'data': correlation_matrix.tolist()}))
623+
setResult(json.dumps({'type': 'distribution', 'real': real_data.to_json(orient="records"), 'synthetic': synthetic_data.to_json(orient="records")}))
627624
628-
setResult(json.dumps({'type': 'heatmap', 'real': real_data.corr().to_json(orient="records"), 'synthetic': synthetic_data.corr().to_json(orient="records")}))
629-
630-
# setResult(json.dumps({'type': 'heatmap', 'data': synthetic_data.corr().to_json(orient="records")}))
631625
return
632626
633627

src/components/componentMapper.tsx

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import { Fragment } from 'react/jsx-runtime';
99
import { Accordion } from './ui/accordion';
1010
import { useTranslation } from 'react-i18next';
1111
import HeatMapChart from './graphs/HeatMap';
12+
import DistributionBarChart from './graphs/DistributionBarChart';
1213

1314
const createArrayFromPythonDictionary = (dict: Record<string, number>) => {
1415
const resultArray = [];
@@ -172,6 +173,57 @@ export default function ComponentMapper({
172173
</ErrorBoundary>
173174
);
174175
}
176+
case 'distribution': {
177+
const realData = JSON.parse(resultItem.real);
178+
const syntheticData = JSON.parse(resultItem.synthetic);
179+
return (
180+
<div key={`distribution-${index}`}>
181+
{realData.length === 0 ||
182+
syntheticData.length === 0
183+
? null
184+
: Object.keys(realData[0]).map(
185+
(
186+
columnName: string,
187+
columnIndex: number
188+
) => {
189+
const realDataColumn =
190+
realData.map(
191+
(
192+
row: Record<
193+
string,
194+
number
195+
>
196+
) => row[columnName]
197+
);
198+
const syntheticDataColumn =
199+
syntheticData.map(
200+
(
201+
row: Record<
202+
string,
203+
number
204+
>
205+
) => row[columnName]
206+
);
207+
return (
208+
<ErrorBoundary
209+
key={columnIndex}
210+
>
211+
<DistributionBarChart
212+
realData={
213+
realDataColumn
214+
}
215+
syntheticData={
216+
syntheticDataColumn
217+
}
218+
column={columnName}
219+
/>
220+
</ErrorBoundary>
221+
);
222+
}
223+
)}
224+
</div>
225+
);
226+
}
175227
case 'heatmap': {
176228
/*
177229
resultItem.real
@@ -202,7 +254,10 @@ export default function ComponentMapper({
202254
data: syntheticData,
203255
} = createHeatmapdata(resultItem.synthetic);
204256
return (
205-
<div className="grid lg:grid-cols-[50%_50%] grid-cols-[100%]">
257+
<div
258+
key={`heatmap-${index}`}
259+
className="grid lg:grid-cols-[50%_50%] grid-cols-[100%]"
260+
>
206261
<div className="col-[1]">
207262
<h2 className="pb-2">
208263
{t('heatmap.realdata')}
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import { useEffect, useRef, useState } from 'react';
2+
import * as d3 from 'd3';
3+
4+
interface DistributionBarChartProps {
5+
column: string;
6+
realData: number[];
7+
syntheticData: number[];
8+
}
9+
10+
// Define margins for the chart
11+
const margin = { top: 10, right: 50, bottom: 40, left: 50 };
12+
// Define height for the chart, adjusting for margins
13+
const height = 300 - margin.top - margin.bottom;
14+
15+
// Define width of bars and adjust for screenwidth
16+
// const barWidth = 0.05 * window.innerWidth < 40 ? 40 : 0.05 * window.innerWidth;
17+
// const barGap = 5;
18+
19+
const DistributionBarChart = ({
20+
column,
21+
realData,
22+
syntheticData,
23+
}: DistributionBarChartProps) => {
24+
const svgRef = useRef<SVGSVGElement>(null); // Reference to the SVG element
25+
const containerRef = useRef<HTMLDivElement>(null); // Reference to the container div
26+
const [containerWidth, setContainerWidth] = useState(800); // Default container width
27+
28+
// Create x-axis scale using d3.scaleBand, with padding for spacing between bars
29+
// const x0 = useMemo(
30+
// () =>
31+
// d3
32+
// .scaleBand()
33+
// .domain(data.map(d => d.name))
34+
// .range([
35+
// 0,
36+
// Math.max(
37+
// containerWidth - margin.right,
38+
// data.length * (barWidth + barGap)
39+
// ),
40+
// ])
41+
// .padding(0.2),
42+
// [data, containerWidth]
43+
// );
44+
45+
// // Create y-axis scale using d3.scaleLinear, with a range from the height to 0
46+
// const y = useMemo(
47+
// () =>
48+
// d3
49+
// .scaleLinear()
50+
// .domain([
51+
// d3.min(data, d => d.values) ?? 0, // Minimum value in the dataset (or 0 if undefined)
52+
// d3.max(data, d => d.values) ?? 0, // Maximum value in the dataset (or 0 if undefined)
53+
// ])
54+
// .nice() // Rounds the domain to nice round values
55+
// .range([height, 0]),
56+
// [data]
57+
// );
58+
59+
useEffect(() => {
60+
const plotWidth = containerWidth - margin.left - margin.right;
61+
const plotHeight = height - margin.top - margin.bottom;
62+
63+
const combinedData = [...realData, ...syntheticData];
64+
const xScale = d3
65+
.scaleLinear()
66+
.domain([d3.min(combinedData) || 0, d3.max(combinedData) || 1])
67+
.range([0, plotWidth]);
68+
69+
const binsReal = d3
70+
.bin()
71+
.domain(xScale.domain() as [number, number])
72+
.thresholds(30)(realData);
73+
74+
const binsSynthetic = d3
75+
.bin()
76+
.domain(xScale.domain() as [number, number])
77+
.thresholds(30)(syntheticData);
78+
79+
// Clear any previous SVG content to avoid overlapping elements
80+
d3.select(svgRef.current).selectAll('*').remove();
81+
82+
// Create the SVG container and set its dimensions
83+
const svg = d3
84+
.select(svgRef.current)
85+
.attr('class', `min-h-[${height}px]`)
86+
.attr('width', containerWidth)
87+
.attr('height', height + margin.top + margin.bottom)
88+
.append('g')
89+
.attr('transform', `translate(${margin.left},${margin.top})`);
90+
// Add axes
91+
svg.append('g')
92+
.attr('transform', `translate(0, ${plotHeight})`)
93+
.call(d3.axisBottom(xScale));
94+
95+
svg.append('defs')
96+
.append('style')
97+
.attr('type', 'text/css')
98+
.text(
99+
"@import url('https://fonts.googleapis.com/css2?family=Avenir:wght@600');"
100+
);
101+
102+
const realDensityFactor = 1 / realData.length;
103+
const syntheticDensityFactor = 1 / syntheticData.length;
104+
105+
const yScale = d3
106+
.scaleLinear()
107+
.domain([
108+
0,
109+
d3.max([
110+
...binsReal.map(bin => bin.length * realDensityFactor),
111+
...binsSynthetic.map(
112+
bin => bin.length * syntheticDensityFactor
113+
),
114+
]) || 1,
115+
])
116+
.range([plotHeight, 0]);
117+
118+
// Add axes
119+
svg.append('g')
120+
.attr('transform', `translate(0, ${plotHeight})`)
121+
.call(d3.axisBottom(xScale));
122+
svg.append('g').call(d3.axisLeft(yScale));
123+
124+
// Draw real data histogram
125+
svg.selectAll('.bar-real')
126+
.data(binsReal)
127+
.enter()
128+
.append('rect')
129+
.attr('class', 'bar-real')
130+
.attr('x', d => xScale(d.x0 || 0))
131+
.attr('y', d => yScale(d.length * realDensityFactor))
132+
.attr('width', d => xScale(d.x1 || 0) - xScale(d.x0 || 0) - 1)
133+
.attr(
134+
'height',
135+
d => plotHeight - yScale(d.length * realDensityFactor)
136+
)
137+
.style('fill', 'steelblue')
138+
.style('opacity', 0.5);
139+
140+
// Draw synthetic data histogram
141+
svg.selectAll('.bar-synthetic')
142+
.data(binsSynthetic)
143+
.enter()
144+
.append('rect')
145+
.attr('class', 'bar-synthetic')
146+
.attr('x', d => xScale(d.x0 || 0))
147+
.attr('y', d => yScale(d.length * syntheticDensityFactor))
148+
.attr('width', d => xScale(d.x1 || 0) - xScale(d.x0 || 0) - 1)
149+
.attr(
150+
'height',
151+
d => plotHeight - yScale(d.length * syntheticDensityFactor)
152+
)
153+
.style('fill', 'orange')
154+
.style('opacity', 0.5);
155+
156+
// Add title
157+
svg.append('text')
158+
.attr('x', plotWidth / 2)
159+
.attr('y', 10)
160+
.attr('text-anchor', 'middle')
161+
.style('font-size', '12px')
162+
.style('font-weight', 'bold')
163+
.text(`Distribution for ${column}`);
164+
165+
// Add a legend label for the mean line
166+
}, [containerWidth, column, realData, syntheticData]);
167+
168+
useEffect(() => {
169+
// Set up the ResizeObserver to track changes in the container's size
170+
const resizeObserver = new ResizeObserver(entries => {
171+
if (!entries || entries.length === 0) return;
172+
const { width } = entries[0].contentRect;
173+
setContainerWidth(width); // Update the state with the new container width
174+
});
175+
176+
if (containerRef.current) {
177+
resizeObserver.observe(containerRef.current); // Start observing the container
178+
}
179+
180+
return () => {
181+
if (containerRef.current) {
182+
resizeObserver.unobserve(containerRef.current); // Cleanup on component unmount
183+
}
184+
};
185+
}, []);
186+
187+
// Render the chart container and SVG element with horizontal scroll if needed
188+
return (
189+
<div
190+
ref={containerRef}
191+
style={{ width: '100%', display: 'flex', overflowX: 'auto' }}
192+
className={`min-h-[${height}px] flex-col`}
193+
>
194+
<svg ref={svgRef}></svg>
195+
</div>
196+
);
197+
};
198+
199+
export default DistributionBarChart;

0 commit comments

Comments
 (0)