Skip to content

Commit 0aa1916

Browse files
committed
numerical vs categorical histograms in CountBarChart
1 parent 3348e9c commit 0aa1916

File tree

3 files changed

+94
-79
lines changed

3 files changed

+94
-79
lines changed

src/assets/synthetic-data.tsx

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,15 @@ def run():
140140
{'type': 'data-set-preview', 'data': ''}
141141
))
142142
143+
143144
dtypes_dict = real_data.dtypes.to_dict()
144145
dtypes_dict = {k: 'float' if v == 'float64' else 'category' if v == 'O' else v for k, v in dtypes_dict.items()}
145146
147+
dtypes_dict['sex'] = 'category'
148+
real_data['sex'] = real_data['sex'].map({1: 'male', 2: 'female'})
149+
150+
cloned_real_data = real_data.copy()
151+
146152
if (sdgMethod == 'cart'):
147153
label_encoders = {}
148154
for column in real_data.select_dtypes(include=['object']).columns:
@@ -201,7 +207,7 @@ def run():
201207
combined_data = pd.concat((real_data.assign(realOrSynthetic='real'), synth_df.assign(realOrSynthetic='synthetic')), keys=['real','synthetic'], names=['Data'])
202208
# combined_data_encoded = pd.concat((df_encoded.assign(realOrSynthetic='real_encoded'), synth_df.assign(realOrSynthetic='synthetic')), keys=['real_encoded','synthetic'], names=['Data'])
203209
204-
setResult(json.dumps({'type': 'distribution', 'real': real_data.to_json(orient="records"), 'synthetic': synthetic_data.to_json(orient="records"), 'dataTypes': json.dumps(dtypes_dict), 'combined_data' : combined_data.to_json(orient="records")}))
210+
setResult(json.dumps({'type': 'distribution', 'real': cloned_real_data.to_json(orient="records"), 'synthetic': synthetic_data.to_json(orient="records"), 'dataTypes': json.dumps(dtypes_dict), 'combined_data' : combined_data.to_json(orient="records")}))
205211
206212
return
207213

src/components/UnivariateCharts.tsx

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,24 @@ export const UnivariateCharts = (props: ChartProps) => {
4141
return column != 'realOrSynthetic';
4242
})
4343
.map((column, index) => {
44-
const isFloatColumn = dataTypes[column] == 'float';
45-
if (isFloatColumn) {
46-
return (
47-
<React.Fragment key={index}>
48-
<h2>{column}</h2>
49-
<CountBarChart
50-
key={index}
51-
column={column}
52-
realData={realData.map(
53-
row =>
54-
(
55-
row as unknown as Record<
56-
string,
57-
object
58-
>
59-
)[column] as unknown as number
60-
)}
61-
/>
62-
</React.Fragment>
63-
);
64-
}
44+
//const isFloatColumn = dataTypes[column] == 'float';
45+
//if (true) {
46+
return (
47+
<React.Fragment key={index}>
48+
<h2>{column}</h2>
49+
<CountBarChart
50+
key={index}
51+
column={column}
52+
realData={realData.map(
53+
row =>
54+
(row as unknown as Record<string, object>)[
55+
column
56+
] as unknown as number
57+
)}
58+
/>
59+
</React.Fragment>
60+
);
61+
//}
6562

6663
return (
6764
<div key={index}>

src/components/graphs/CountBarChart.tsx

Lines changed: 69 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,116 +4,128 @@ import { useTranslation } from 'react-i18next';
44

55
interface CountBarChartProps {
66
column: string;
7-
realData: number[];
7+
realData: (string | number)[];
88
}
99

10-
// Define margins for the chart
11-
const margin = { top: 30, right: 50, bottom: 40, left: 50 };
12-
// Define height for the chart, adjusting for margins
10+
const margin = { top: 30, right: 50, bottom: 60, left: 50 }; // Increased bottom margin for rotated labels
1311
const height = 300 - margin.top - margin.bottom;
1412

15-
// Define width of bars and adjust for screenwidth
16-
// const barWidth = 0.05 * window.innerWidth < 40 ? 40 : 0.05 * window.innerWidth;
17-
// const barGap = 5;
18-
1913
const CountBarChart = ({ column, realData }: CountBarChartProps) => {
20-
const svgRef = useRef<SVGSVGElement>(null); // Reference to the SVG element
21-
const containerRef = useRef<HTMLDivElement>(null); // Reference to the container div
22-
const [containerWidth, setContainerWidth] = useState(800); // Default container width
14+
const svgRef = useRef<SVGSVGElement>(null);
15+
const containerRef = useRef<HTMLDivElement>(null);
16+
const [containerWidth, setContainerWidth] = useState(800);
2317
const { t } = useTranslation();
18+
2419
useEffect(() => {
2520
const plotWidth = containerWidth - margin.left - margin.right;
2621
const plotHeight = height - margin.top - margin.bottom;
2722

28-
const xScale = d3
29-
.scaleLinear()
30-
.domain([d3.min(realData) || 0, d3.max(realData) || 1])
31-
.range([0, plotWidth]);
32-
33-
const binsReal = d3
34-
.bin()
35-
.domain(xScale.domain() as [number, number])
36-
.thresholds(30)(realData);
23+
// Determine if data is categorical or numerical
24+
const isNumerical = realData.every(d => typeof d === 'number');
25+
26+
// Process data based on type
27+
let processedData;
28+
if (isNumerical) {
29+
// For numerical data, create bins
30+
const numericData = realData as number[];
31+
const bins = d3
32+
.bin()
33+
.domain(d3.extent(numericData) as [number, number])
34+
.thresholds(10)(numericData);
35+
36+
processedData = bins.map(bin => ({
37+
key: `${bin.x0?.toFixed(2)} - ${bin.x1?.toFixed(2)}`,
38+
value: bin.length,
39+
}));
40+
} else {
41+
// For categorical data, count occurrences
42+
const counts = d3.rollup(
43+
realData,
44+
v => v.length,
45+
d => d
46+
);
47+
processedData = Array.from(counts, ([key, value]) => ({
48+
key,
49+
value,
50+
}));
51+
}
3752

38-
// Clear any previous SVG content to avoid overlapping elements
53+
// Clear previous content
3954
d3.select(svgRef.current).selectAll('*').remove();
4055

41-
// Create the SVG container and set its dimensions
56+
// Create SVG
4257
const svg = d3
4358
.select(svgRef.current)
44-
.attr('class', `min-h-[${height}px]`)
4559
.attr('width', containerWidth)
4660
.attr('height', height + margin.top + margin.bottom)
4761
.append('g')
4862
.attr('transform', `translate(${margin.left},${margin.top})`);
49-
// Add axes
50-
svg.append('g')
51-
.attr('transform', `translate(0, ${plotHeight})`)
52-
.call(d3.axisBottom(xScale));
53-
54-
svg.append('defs')
55-
.append('style')
56-
.attr('type', 'text/css')
57-
.text(
58-
"@import url('https://fonts.googleapis.com/css2?family=Avenir:wght@600');"
59-
);
6063

61-
//const realDensityFactor = 1 / realData.length;
64+
// Create scales
65+
const xScale = d3
66+
.scaleBand()
67+
.domain(processedData.map(d => String(d.key)))
68+
.range([0, plotWidth])
69+
.padding(0.1);
6270

6371
const yScale = d3
6472
.scaleLinear()
65-
.domain([0, d3.max([...binsReal.map(bin => bin.length)]) || 1])
73+
.domain([0, d3.max(processedData, d => d.value) || 0])
6674
.range([plotHeight, 0]);
6775

68-
// Add axes
69-
svg.append('g')
70-
.attr('transform', `translate(0, ${plotHeight})`)
71-
.call(d3.axisBottom(xScale));
72-
svg.append('g').call(d3.axisLeft(yScale));
73-
74-
// Draw real data histogram
75-
svg.selectAll('.bar-real')
76-
.data(binsReal)
76+
// Add bars
77+
svg.selectAll('.bar')
78+
.data(processedData)
7779
.enter()
7880
.append('rect')
79-
.attr('class', 'bar-real')
80-
.attr('x', d => xScale(d.x0 || 0))
81-
.attr('y', d => yScale(d.length))
82-
.attr('width', d => xScale(d.x1 || 0) - xScale(d.x0 || 0) - 1)
83-
.attr('height', d => plotHeight - yScale(d.length))
81+
.attr('class', 'bar')
82+
.attr('x', d => xScale(String(d.key)) || 0)
83+
.attr('y', d => yScale(d.value))
84+
.attr('width', xScale.bandwidth())
85+
.attr('height', d => plotHeight - yScale(d.value))
8486
.style('fill', 'steelblue')
8587
.style('opacity', 0.5);
8688

89+
// Add axes
90+
svg.append('g')
91+
.attr('transform', `translate(0,${plotHeight})`)
92+
.call(d3.axisBottom(xScale))
93+
.selectAll('text')
94+
.attr('transform', 'rotate(-45)')
95+
.style('text-anchor', 'end')
96+
.attr('dx', '-.8em')
97+
.attr('dy', '.15em');
98+
99+
svg.append('g').call(d3.axisLeft(yScale));
100+
87101
// Add title
88102
svg.append('text')
89103
.attr('x', plotWidth / 2)
90-
.attr('y', 10)
104+
.attr('y', -10)
91105
.attr('text-anchor', 'middle')
92106
.style('font-size', '12px')
93107
.style('font-weight', 'bold')
94-
.text(`${t('distribution.distributionFor')} ${column}`);
108+
.text(`${t('distribution.countFor')} ${column}`);
95109
}, [containerWidth, column, realData]);
96110

97111
useEffect(() => {
98-
// Set up the ResizeObserver to track changes in the container's size
99112
const resizeObserver = new ResizeObserver(entries => {
100113
if (!entries || entries.length === 0) return;
101114
const { width } = entries[0].contentRect;
102-
setContainerWidth(width); // Update the state with the new container width
115+
setContainerWidth(width);
103116
});
104117

105118
if (containerRef.current) {
106-
resizeObserver.observe(containerRef.current); // Start observing the container
119+
resizeObserver.observe(containerRef.current);
107120
}
108121

109122
return () => {
110123
if (containerRef.current) {
111-
resizeObserver.unobserve(containerRef.current); // Cleanup on component unmount
124+
resizeObserver.unobserve(containerRef.current);
112125
}
113126
};
114127
}, []);
115128

116-
// Render the chart container and SVG element with horizontal scroll if needed
117129
return (
118130
<div
119131
ref={containerRef}

0 commit comments

Comments
 (0)