Skip to content

Commit 93667d6

Browse files
committed
demo run also uses parameters from UI where applicable
1 parent 4e989b7 commit 93667d6

File tree

3 files changed

+38
-76
lines changed

3 files changed

+38
-76
lines changed

src/assets/bias-detection-python-code.tsx

Lines changed: 2 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -165,58 +165,6 @@ def chi2_test_on_cluster(decoded_X_test, bias_score, cluster_label):
165165
166166
return comparisons
167167
168-
def diffDataframe(df, features, type=None, cluster1=None, cluster2=None):
169-
'''
170-
Creates difference dataframe, for numerical and categorical
171-
data: Takes dataframe of two clusters of interest and
172-
computes difference in means. Default to analyze most deviating
173-
cluster vs rest of the dataset, except specified otherwise.
174-
'''
175-
# Cluster comparison (optional)
176-
if cluster1 != None and cluster2 != None:
177-
df1 = df[df['Cluster'] == cluster1]
178-
df2 = df[df['Cluster'] == cluster2]
179-
else:
180-
df1 = df[df['Cluster'] == 0]
181-
df2 = df[df['Cluster'] != 0]
182-
183-
n_df1 = df1.shape[0]
184-
n_df2 = df2.shape[0]
185-
186-
diff_dict = {}
187-
CI_dict = {}
188-
189-
for feat in features:
190-
sample1 = df1[feat]
191-
sample2 = df2[feat]
192-
193-
if type == 'Numerical':
194-
mean1 = np.mean(sample1)
195-
mean2 = np.mean(sample2)
196-
diff = mean1 - mean2
197-
diff_dict[feat] = diff
198-
else:
199-
freq1 = sample1.value_counts()
200-
freq2 = sample2.value_counts()
201-
diff = freq1 - freq2
202-
diff_dict[feat] = diff
203-
204-
if type == 'Numerical':
205-
pd.set_option('display.float_format', lambda x: '%.5f' % x)
206-
diff_df = pd.DataFrame.from_dict(diff_dict, orient='index', columns=['Difference'])
207-
else:
208-
diff_df = pd.DataFrame()
209-
pd.set_option('display.float_format', lambda x: '%.5f' % x)
210-
211-
for _, value in diff_dict.items():
212-
df_temp = pd.DataFrame(value)
213-
diff_df = pd.concat([diff_df,df_temp], axis=0,)
214-
215-
diff_df = diff_df.fillna(0)
216-
diff_df.columns = ['Difference']
217-
218-
return(diff_df)
219-
220168
def run():
221169
csv_data = StringIO(data)
222170
df = pd.read_csv(csv_data)
@@ -227,7 +175,7 @@ def run():
227175
if isDemo:
228176
bias_score = "false_positive"
229177
localDataType = "categorical"
230-
localIterations = 20
178+
localIterations = iterations # 20
231179
232180
print (f"Using demo parameters: bias_score={bias_score}, targetColumn={targetColumn}, dataType={localDataType}, iterations={iterations}")
233181
@@ -297,7 +245,7 @@ def run():
297245
print(f"X_train shape: {X_train.shape}")
298246
299247
if isDemo:
300-
localClusterSize = X_train.shape[0]*0.01
248+
localClusterSize = clusterSize # X_train.shape[0]*0.01
301249
else:
302250
localClusterSize = clusterSize
303251

src/components/BiasSettings.tsx

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import {
88
} from '@/components/ui/select';
99
import { Slider } from '@/components/ui/slider';
1010
import { RadioGroup, RadioGroupItem } from '@/components/ui/radio-group';
11-
import CSVReader, { csvReader } from './CSVReader';
11+
import CSVReader from './CSVReader';
1212
import { useEffect, useState } from 'react';
1313
import { Button } from './ui/button';
1414
import { ArrowDown, ArrowRight, InfoIcon } from 'lucide-react';
@@ -58,7 +58,14 @@ export default function BiasSettings({
5858
isErrorDuringAnalysis,
5959
}: {
6060
onRun: (params: BiasDetectionParameters) => void;
61-
onDataLoad: csvReader['onChange'];
61+
onDataLoad: (
62+
data: Record<string, string>[],
63+
stringified: string,
64+
fileName: string,
65+
demo?: boolean,
66+
columnsCount?: number,
67+
params?: BiasDetectionParameters
68+
) => void;
6269
isLoading: boolean;
6370
isErrorDuringAnalysis: boolean;
6471
isInitialised: boolean;
@@ -144,7 +151,18 @@ export default function BiasSettings({
144151
file.data as Record<string, string>[],
145152
Papa.unparse(file.data),
146153
'demo',
147-
true
154+
true,
155+
undefined,
156+
{
157+
clusterSize: clusters[0],
158+
iterations: iter[0],
159+
targetColumn: '',
160+
dataType: '',
161+
higherIsBetter:
162+
form.getValues().whichPerformanceMetricValueIsBetter ===
163+
'higher',
164+
isDemo: true,
165+
}
148166
);
149167
};
150168

src/routes/BiasDetection.tsx

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import { useEffect, useRef, useState } from 'react';
22
import { pythonCode } from '@/assets/bias-detection-python-code';
33
import { usePython } from '@/components/pyodide/use-python';
44
import BiasSettings from '@/components/BiasSettings';
5-
import { csvReader } from '@/components/CSVReader';
65
import { cn } from '@/lib/utils';
76
import ComponentMapper from '@/components/componentMapper';
87
import { useReactToPrint } from 'react-to-print';
@@ -45,11 +44,14 @@ const PAGE_STYLE = `
4544
`;
4645

4746
export default function BiasDetection() {
48-
const [data, setData] = useState<CSVData>({
47+
const [data, setData] = useState<
48+
CSVData & { params?: BiasDetectionParameters }
49+
>({
4950
data: [],
5051
stringified: '',
5152
fileName: '',
5253
demo: false,
54+
params: undefined,
5355
});
5456
const { t, i18n } = useTranslation();
5557

@@ -80,14 +82,15 @@ export default function BiasDetection() {
8082
higherIsBetter: false,
8183
isDemo: false,
8284
});
83-
84-
const onFileLoad: csvReader['onChange'] = (
85-
data,
86-
stringified,
87-
fileName,
88-
demo
89-
) => {
90-
setData({ data, stringified, fileName, demo });
85+
const onFileLoad: (
86+
data: Record<string, string>[],
87+
stringified: string,
88+
fileName: string,
89+
demo?: boolean,
90+
columnsCount?: number,
91+
params?: BiasDetectionParameters
92+
) => void = (data, stringified, fileName, demo, _columnsCount, params) => {
93+
setData({ data, stringified, fileName, demo, params });
9194
};
9295

9396
useEffect(() => {
@@ -108,15 +111,8 @@ export default function BiasDetection() {
108111
if (pythonCode && data.stringified.length >= 0 && initialised) {
109112
sendData(data.stringified);
110113
}
111-
if (data.demo) {
112-
onRun({
113-
iterations: 3,
114-
clusterSize: 3,
115-
targetColumn: 'FP',
116-
dataType: 'numeric',
117-
higherIsBetter: true,
118-
isDemo: true,
119-
});
114+
if (data.demo && data.params) {
115+
onRun(data.params);
120116
}
121117
}, [initialised, data]);
122118

0 commit comments

Comments
 (0)