Skip to content

Commit 47d6904

Browse files
authored
Merge pull request #21 from NGO-Algorithm-Audit/feature/univariate-etc-charts
Feature/univariate etc charts first iteration
2 parents 86f3bdb + 0aa1916 commit 47d6904

File tree

7 files changed

+254
-8
lines changed

7 files changed

+254
-8
lines changed

src/assets/bias-detection-python-code.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def run():
267267
if col != targetColumn and col != 'Cluster' and col != "":
268268
setResult(json.dumps({
269269
'type': 'heading',
270-
'key': 'biasAnalysis.distribution.heading',
270+
'headingKey': 'biasAnalysis.distribution.heading',
271271
'params': {'variable': col}
272272
}))
273273
@@ -281,7 +281,7 @@ def run():
281281
if col != targetColumn and col != 'Cluster' and col != "" and col in features:
282282
setResult(json.dumps({
283283
'type': 'heading',
284-
'key': 'biasAnalysis.distribution.heading',
284+
'headingKey': 'biasAnalysis.distribution.heading',
285285
'params': {'variable': col}
286286
}))
287287

src/assets/synthetic-data.tsx

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,15 @@ def run():
141141
))
142142
143143
144+
dtypes_dict = real_data.dtypes.to_dict()
145+
dtypes_dict = {k: 'float' if v == 'float64' else 'category' if v == 'O' else v for k, v in dtypes_dict.items()}
146+
147+
dtypes_dict['sex'] = 'category'
148+
real_data['sex'] = real_data['sex'].map({1: 'male', 2: 'female'})
149+
150+
cloned_real_data = real_data.copy()
151+
144152
if (sdgMethod == 'cart'):
145-
dtypes_dict = real_data.dtypes.to_dict()
146-
dtypes_dict = {k: 'float' if v == 'float64' else 'category' if v == 'O' else v for k, v in dtypes_dict.items()}
147153
label_encoders = {}
148154
for column in real_data.select_dtypes(include=['object']).columns:
149155
label_encoders[column] = LabelEncoder()
@@ -193,7 +199,15 @@ def run():
193199
194200
setResult(json.dumps({'type': 'heatmap', 'real': real_data.corr().to_json(orient="records"), 'synthetic': synthetic_data.corr().to_json(orient="records")}))
195201
196-
setResult(json.dumps({'type': 'distribution', 'real': real_data.to_json(orient="records"), 'synthetic': synthetic_data.to_json(orient="records")}))
202+
# copy dataframe and assign NaN to all values
203+
synth_df = real_data.copy()
204+
synth_df[:] = np.nan
205+
206+
# combine empty synthetic data with original data and with encoded data
207+
combined_data = pd.concat((real_data.assign(realOrSynthetic='real'), synth_df.assign(realOrSynthetic='synthetic')), keys=['real','synthetic'], names=['Data'])
208+
# combined_data_encoded = pd.concat((df_encoded.assign(realOrSynthetic='real_encoded'), synth_df.assign(realOrSynthetic='synthetic')), keys=['real_encoded','synthetic'], names=['Data'])
209+
210+
setResult(json.dumps({'type': 'distribution', 'real': cloned_real_data.to_json(orient="records"), 'synthetic': synthetic_data.to_json(orient="records"), 'dataTypes': json.dumps(dtypes_dict), 'combined_data' : combined_data.to_json(orient="records")}))
197211
198212
return
199213
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import React from 'react';
2+
import CountBarChart from './graphs/CountBarChart';
3+
4+
export interface ChartProps {
5+
realData: object[];
6+
syntheticData: object[];
7+
dataTypes: Record<string, string>;
8+
combined_data: object;
9+
comparison: boolean;
10+
}
11+
12+
export const UnivariateCharts = (props: ChartProps) => {
13+
if (
14+
props.realData &&
15+
props.syntheticData &&
16+
props.dataTypes &&
17+
props.combined_data &&
18+
props.realData.length > 0 &&
19+
props.syntheticData.length > 0
20+
) {
21+
const realData = props.realData;
22+
//const syntheticData = props.syntheticData;
23+
const dataTypes = props.dataTypes;
24+
//const combinedData = props.combined_data;
25+
const realDataKeys = Object.keys(realData[0]);
26+
//const syntheticDataKeys = Object.keys(syntheticData[0]);
27+
//const combinedDataKeys = Object.keys(combinedData[0]);
28+
29+
/*
30+
31+
<SingleBarChart
32+
key={index}
33+
data={barchartData}
34+
title={resultItem.title ?? ''}
35+
/>
36+
37+
*/
38+
39+
return realDataKeys
40+
.filter(column => {
41+
return column != 'realOrSynthetic';
42+
})
43+
.map((column, index) => {
44+
//const isFloatColumn = dataTypes[column] == 'float';
45+
//if (true) {
46+
return (
47+
<React.Fragment key={index}>
48+
<h2>{column}</h2>
49+
<CountBarChart
50+
key={index}
51+
column={column}
52+
realData={realData.map(
53+
row =>
54+
(row as unknown as Record<string, object>)[
55+
column
56+
] as unknown as number
57+
)}
58+
/>
59+
</React.Fragment>
60+
);
61+
//}
62+
63+
return (
64+
<div key={index}>
65+
{column} : {dataTypes[column]}
66+
{}
67+
</div>
68+
);
69+
});
70+
}
71+
return <div></div>;
72+
};

src/components/componentMapper.tsx

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { Accordion } from './ui/accordion';
1010
import { useTranslation } from 'react-i18next';
1111
import HeatMapChart from './graphs/HeatMap';
1212
import DistributionBarChart from './graphs/DistributionBarChart';
13+
import { UnivariateCharts } from './UnivariateCharts';
1314

1415
const createArrayFromPythonDictionary = (dict: Record<string, number>) => {
1516
const resultArray = [];
@@ -131,7 +132,10 @@ export default function ComponentMapper({
131132
className="text-gray-800 font-semibold"
132133
>
133134
{resultItem.headingKey
134-
? t(resultItem.headingKey)
135+
? t(
136+
resultItem.headingKey,
137+
resultItem.params
138+
)
135139
: resultItem.data}
136140
</h5>
137141
);
@@ -174,10 +178,26 @@ export default function ComponentMapper({
174178
);
175179
}
176180
case 'distribution': {
181+
console.log(
182+
'distribution',
183+
JSON.parse(resultItem.dataTypes)
184+
);
185+
177186
const realData = JSON.parse(resultItem.real);
178187
const syntheticData = JSON.parse(resultItem.synthetic);
188+
189+
console.log('realData', realData);
179190
return (
180191
<div key={`distribution-${index}`}>
192+
<UnivariateCharts
193+
realData={realData}
194+
syntheticData={syntheticData}
195+
dataTypes={JSON.parse(resultItem.dataTypes)}
196+
combined_data={JSON.parse(
197+
resultItem.combined_data
198+
)}
199+
comparison={false}
200+
/>
181201
{realData.length === 0 ||
182202
syntheticData.length === 0
183203
? null
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import { useEffect, useRef, useState } from 'react';
2+
import * as d3 from 'd3';
3+
import { useTranslation } from 'react-i18next';
4+
5+
interface CountBarChartProps {
6+
column: string;
7+
realData: (string | number)[];
8+
}
9+
10+
const margin = { top: 30, right: 50, bottom: 60, left: 50 }; // Increased bottom margin for rotated labels
11+
const height = 300 - margin.top - margin.bottom;
12+
13+
const CountBarChart = ({ column, realData }: CountBarChartProps) => {
14+
const svgRef = useRef<SVGSVGElement>(null);
15+
const containerRef = useRef<HTMLDivElement>(null);
16+
const [containerWidth, setContainerWidth] = useState(800);
17+
const { t } = useTranslation();
18+
19+
useEffect(() => {
20+
const plotWidth = containerWidth - margin.left - margin.right;
21+
const plotHeight = height - margin.top - margin.bottom;
22+
23+
// Determine if data is categorical or numerical
24+
const isNumerical = realData.every(d => typeof d === 'number');
25+
26+
// Process data based on type
27+
let processedData;
28+
if (isNumerical) {
29+
// For numerical data, create bins
30+
const numericData = realData as number[];
31+
const bins = d3
32+
.bin()
33+
.domain(d3.extent(numericData) as [number, number])
34+
.thresholds(10)(numericData);
35+
36+
processedData = bins.map(bin => ({
37+
key: `${bin.x0?.toFixed(2)} - ${bin.x1?.toFixed(2)}`,
38+
value: bin.length,
39+
}));
40+
} else {
41+
// For categorical data, count occurrences
42+
const counts = d3.rollup(
43+
realData,
44+
v => v.length,
45+
d => d
46+
);
47+
processedData = Array.from(counts, ([key, value]) => ({
48+
key,
49+
value,
50+
}));
51+
}
52+
53+
// Clear previous content
54+
d3.select(svgRef.current).selectAll('*').remove();
55+
56+
// Create SVG
57+
const svg = d3
58+
.select(svgRef.current)
59+
.attr('width', containerWidth)
60+
.attr('height', height + margin.top + margin.bottom)
61+
.append('g')
62+
.attr('transform', `translate(${margin.left},${margin.top})`);
63+
64+
// Create scales
65+
const xScale = d3
66+
.scaleBand()
67+
.domain(processedData.map(d => String(d.key)))
68+
.range([0, plotWidth])
69+
.padding(0.1);
70+
71+
const yScale = d3
72+
.scaleLinear()
73+
.domain([0, d3.max(processedData, d => d.value) || 0])
74+
.range([plotHeight, 0]);
75+
76+
// Add bars
77+
svg.selectAll('.bar')
78+
.data(processedData)
79+
.enter()
80+
.append('rect')
81+
.attr('class', 'bar')
82+
.attr('x', d => xScale(String(d.key)) || 0)
83+
.attr('y', d => yScale(d.value))
84+
.attr('width', xScale.bandwidth())
85+
.attr('height', d => plotHeight - yScale(d.value))
86+
.style('fill', 'steelblue')
87+
.style('opacity', 0.5);
88+
89+
// Add axes
90+
svg.append('g')
91+
.attr('transform', `translate(0,${plotHeight})`)
92+
.call(d3.axisBottom(xScale))
93+
.selectAll('text')
94+
.attr('transform', 'rotate(-45)')
95+
.style('text-anchor', 'end')
96+
.attr('dx', '-.8em')
97+
.attr('dy', '.15em');
98+
99+
svg.append('g').call(d3.axisLeft(yScale));
100+
101+
// Add title
102+
svg.append('text')
103+
.attr('x', plotWidth / 2)
104+
.attr('y', -10)
105+
.attr('text-anchor', 'middle')
106+
.style('font-size', '12px')
107+
.style('font-weight', 'bold')
108+
.text(`${t('distribution.countFor')} ${column}`);
109+
}, [containerWidth, column, realData]);
110+
111+
useEffect(() => {
112+
const resizeObserver = new ResizeObserver(entries => {
113+
if (!entries || entries.length === 0) return;
114+
const { width } = entries[0].contentRect;
115+
setContainerWidth(width);
116+
});
117+
118+
if (containerRef.current) {
119+
resizeObserver.observe(containerRef.current);
120+
}
121+
122+
return () => {
123+
if (containerRef.current) {
124+
resizeObserver.unobserve(containerRef.current);
125+
}
126+
};
127+
}, []);
128+
129+
return (
130+
<div
131+
ref={containerRef}
132+
style={{ width: '100%', display: 'flex', overflowX: 'auto' }}
133+
className={`min-h-[${height}px] flex-col`}
134+
>
135+
<svg ref={svgRef}></svg>
136+
</div>
137+
);
138+
};
139+
140+
export default CountBarChart;

src/components/graphs/SingleBarChart.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ const SingleBarChart = ({ title, data }: SingleBarChartProps) => {
2424
const svgRef = useRef<SVGSVGElement>(null); // Reference to the SVG element
2525
const containerRef = useRef<HTMLDivElement>(null); // Reference to the container div
2626
const [containerWidth, setContainerWidth] = useState(800); // Default container width
27-
27+
console.log('SingleBarChart', data);
2828
// Create x-axis scale using d3.scaleBand, with padding for spacing between bars
2929
const x0 = useMemo(
3030
() =>

src/locales/en.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"error": "Sorry, something went wrong.",
1111
"loadingMessage": "Setting up environment...",
1212
"mostBiasedCluster": "Most biased\n cluster",
13-
"cluster": "Cluster",
13+
"cluster": "Cluster {{value}}",
1414
"exportToPDF": "Export to PDF",
1515
"exportToJSON": "Export synthetic data to JSON",
1616
"downloadButton": "Download",

0 commit comments

Comments
 (0)