|
3 | 3 | import matplotlib |
4 | 4 | import matplotlib.pyplot as plt |
5 | 5 | from matplotlib.colors import LogNorm |
| 6 | +import warnings |
6 | 7 |
|
7 | 8 | from sklearn import metrics |
| 9 | +from sklearn.exceptions import UndefinedMetricWarning |
8 | 10 | from sklearn.calibration import CalibratedClassifierCV |
9 | 11 |
|
10 | 12 | from .preprocessing import horizontal_to_camera |
@@ -97,6 +99,7 @@ def plot_bias_resolution( |
97 | 99 | binned['upper_sigma'] = grouped['rel_error'].agg(lambda s: np.percentile(s, 85)) |
98 | 100 | binned['resolution_quantiles'] = (binned.upper_sigma - binned.lower_sigma) / 2 |
99 | 101 | binned['resolution'] = grouped['rel_error'].std() |
| 102 | +    binned = binned[grouped.count() > 5]  # more than five events per bin |
100 | 103 |
|
101 | 104 | for key in ('bias', 'resolution', 'resolution_quantiles'): |
102 | 105 | if matplotlib.get_backend() == 'pgf' or plt.rcParams['text.usetex']: |
@@ -335,20 +338,45 @@ def r2(group): |
335 | 338 | 'e_width': np.diff(edges), |
336 | 339 | }, index=pd.Series(np.arange(1, len(edges)), name='bin_idx')) |
337 | 340 |
|
338 | | - binned['accuracy'] = df.groupby('bin_idx').apply(accuracy) |
339 | | - binned['r2_score'] = df.groupby('bin_idx').apply(r2) |
| 341 | + r2_scores = pd.DataFrame(index=binned.index) |
| 342 | + accuracies = pd.DataFrame(index=binned.index) |
| 343 | + counts = pd.DataFrame(index=binned.index) |
| 344 | + |
| 345 | + with warnings.catch_warnings(): |
| 346 | + # warns when there are less than 2 events for calculating metrics, |
| 347 | + # but we throw those away anyways |
| 348 | + warnings.filterwarnings('ignore', category=UndefinedMetricWarning) |
| 349 | + for cv_fold, cv in df.groupby('cv_fold'): |
| 350 | + grouped = cv.groupby('bin_idx') |
| 351 | + accuracies[cv_fold] = grouped.apply(accuracy) |
| 352 | + r2_scores[cv_fold] = grouped.apply(r2) |
| 353 | + counts[cv_fold] = grouped.size() |
| 354 | + |
| 355 | + binned['r2_score'] = r2_scores.mean(axis=1) |
| 356 | + binned['r2_std'] = r2_scores.std(axis=1) |
| 357 | + binned['accuracy'] = accuracies.mean(axis=1) |
| 358 | + binned['accuracy_std'] = accuracies.std(axis=1) |
| 359 | +    # more than 10 events in at least one crossval iteration |
| 360 | + binned['valid'] = (counts > 10).any(axis=1) |
| 361 | + binned = binned.query('valid') |
340 | 362 |
|
341 | 363 | fig = fig or plt.figure() |
342 | 364 |
|
343 | 365 | ax1 = fig.add_subplot(2, 1, 1) |
344 | 366 | ax2 = fig.add_subplot(2, 1, 2, sharex=ax1) |
345 | 367 |
|
346 | 368 | ax1.errorbar( |
347 | | - binned.e_center, binned.accuracy, xerr=binned.e_width / 2, ls='', |
| 369 | + binned.e_center, binned.accuracy, |
| 370 | + yerr=binned.accuracy_std, xerr=binned.e_width / 2, |
| 371 | + ls='', |
348 | 372 | ) |
349 | 373 | ax1.set_ylabel(r'Accuracy for $\mathrm{sgn} \mathtt{disp}$') |
350 | 374 |
|
351 | | - ax2.errorbar(binned.e_center, binned.r2_score, xerr=binned.e_width / 2, ls='') |
| 375 | + ax2.errorbar( |
| 376 | + binned.e_center, binned.r2_score, |
| 377 | + yerr=binned.r2_std, xerr=binned.e_width / 2, |
| 378 | + ls='', |
| 379 | + ) |
352 | 380 | ax2.set_ylabel(r'$r^2$ score for $|\mathtt{disp}|$') |
353 | 381 |
|
354 | 382 | ax2.set_xlabel( |
|
0 commit comments