Skip to content

Commit 4000a42

Browse files
committed
add option to choose "worst" line by maximum wer area under the curve
1 parent 69dd7e5 commit 4000a42

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

notebooks/Eval/ASR_eval.ipynb

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,7 @@
760760
},
761761
{
762762
"cell_type": "code",
763-
"execution_count": 42,
763+
"execution_count": 50,
764764
"id": "a11013dc",
765765
"metadata": {},
766766
"outputs": [
@@ -947,6 +947,7 @@
947947
"REQUIRE_NATIVE = True\n",
948948
"# WORST_CRITERIA = 'max_regression'\n",
949949
"WORST_CRITERIA = 'avg_regression'\n",
950+
"# WORST_CRITERIA = 'area'\n",
950951
"INCLUDE_UNCONDITIONAL = False\n",
951952
"PLOT_TOP_TEN_DIALECTS = True\n",
952953
"# STATISTIC = 'median'\n",
@@ -1029,7 +1030,16 @@
10291030
" # Convert to DataFrame and print\n",
10301031
" reg_df = pd.DataFrame(res)\n",
10311032
"\n",
1032-
" top_lang = reg_df.loc[reg_df[WORST_CRITERIA].idxmax(), 'native_language']\n",
1033+
" if WORST_CRITERIA == 'area':\n",
1034+
" top_lang = None\n",
1035+
" max_area = -1\n",
1036+
" for l in reg_df['native_language'].unique():\n",
1037+
" area = df_d[df_d['native_language'] == l].groupby(['release_date', 'model']).agg({'wer': STATISTIC}).sum()['wer']\n",
1038+
" if area > max_area:\n",
1039+
" top_lang = l\n",
1040+
" max_area = area\n",
1041+
" else:\n",
1042+
" top_lang = reg_df.loc[reg_df[WORST_CRITERIA].idxmax(), 'native_language']\n",
10331043
" if REQUIRE_NATIVE and not (native_lang and native_lang != top_lang):\n",
10341044
" continue\n",
10351045
"\n",

0 commit comments

Comments
 (0)