Skip to content

Commit 84f3c28

Browse files
committed
Added AIC and BIC to the fit metrics (metrics are now [r2, rmse, aic, bic])
1 parent 352a3e0 commit 84f3c28

File tree

2 files changed

+59
-8
lines changed

2 files changed

+59
-8
lines changed

sidpy/proc/fitter_refactor.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def residuals(p, x, y):
181181
raise ValueError(f"Unknown cov_mode={cov_mode!r}.")
182182

183183
if return_metrics:
184-
out_parts.append(np.array([np.nan, np.nan], dtype=np.float64))
184+
out_parts.append(np.array([np.nan, np.nan, np.nan, np.nan], dtype=np.float64))
185185

186186
return np.hstack(out_parts) if len(out_parts) > 1 else params
187187

@@ -199,6 +199,8 @@ def residuals(p, x, y):
199199

200200
# --- Metrics (default) appending ---
201201
if return_metrics:
202+
n = int(y_input.size)
203+
k = int(params.size)
202204
sse = float(np.sum(res.fun ** 2))
203205
rmse = float(np.sqrt(sse / y_input.size)) if y_input.size > 0 else np.nan
204206
y_mean = float(np.mean(y_input))
@@ -208,7 +210,17 @@ def residuals(p, x, y):
208210
else:
209211
r2 = 1.0 - (sse / sst)
210212

211-
out_parts.append(np.array([r2, rmse], dtype=np.float64))
213+
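# AIC/BIC for Gaussian residuals via n*ln(SSE/n); the additive constant n*(ln(2*pi) + 1) is dropped, so only differences between fits to the same data are meaningful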
214+
if n > 0 and sse > 0.0:
215+
ll_term = n * np.log(sse / n)
216+
aic = float(ll_term + 2.0 * k)
217+
bic = float(ll_term + k * np.log(n))
218+
else:
219+
aic = np.nan
220+
bic = np.nan
221+
222+
# Order: [r2, rmse, aic, bic]
223+
out_parts.append(np.array([r2, rmse, aic, bic], dtype=np.float64))
212224

213225
# --- Covariance (optional) appending ---
214226
if cov_mode is not None:
@@ -504,7 +516,7 @@ def do_fit(self, guesses=None, use_kmeans=False, n_clusters=10,
504516
else:
505517
raise ValueError(f"Unknown cov_mode={cov_mode!r}. Use None, 'full', 'diag', or 'stderr'.")
506518

507-
metrics_size = 2 if return_metrics else 0
519+
metrics_size = 4 if return_metrics else 0
508520
out_dim = self.num_params + cov_size + metrics_size
509521

510522

@@ -627,19 +639,21 @@ def set_spatial_dims(dset):
627639

628640
# --- 2) Metrics dataset ---
629641
if return_metrics:
630-
metrics_payload = fit_dask_array[..., cursor:cursor + 2]
642+
metrics_payload = fit_dask_array[..., cursor:cursor + 4]
643+
cursor += 4
631644

632645
sid_metrics = sid.Dataset.from_array(metrics_payload, title=f"{self.dataset.title}_fit_metrics")
633646
set_spatial_dims(sid_metrics)
634647
sid_metrics.set_dimension(
635648
len(self.spat_dims),
636-
sid.Dimension(np.arange(2), name='metrics', quantity='index', dimension_type='spectral')
649+
sid.Dimension(np.arange(4), name='metrics', quantity='index', dimension_type='spectral')
637650
)
638651
sid_metrics.metadata = dict(self.metadata).copy()
639-
sid_metrics.metadata.setdefault("fit_parameters", {}).update({"metrics": ["r2", "rmse"]})
652+
sid_metrics.metadata.setdefault("fit_parameters", {}).update({"metrics": ["r2", "rmse", "aic", "bic"]})
640653
sid_metrics.provenance = {'sidpy': {'generated_from': self.dataset.title, 'parent_fit': sid_params.title}}
641654
out.append(sid_metrics)
642655

656+
643657
# --- 3) Covariance dataset (optional) ---
644658
if cov_size > 0:
645659
cov_payload = fit_dask_array[..., cursor:cursor + cov_size]

tests/proc/test_fitter_refactor.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,15 +357,52 @@ def test_metrics_perfect_fit_r2_one_rmse_zero(self):
357357
res_params, res_metrics = fitter.do_fit(use_kmeans=False, loss='linear', return_metrics=True)
358358

359359
assert res_params.shape == (nx, ny, 6)
360-
assert res_metrics.shape == (nx, ny, 2)
360+
assert res_metrics.shape == (nx, ny, 4)
361361

362362
m = np.asarray(res_metrics)
363363
r2 = m[..., 0]
364364
rmse = m[..., 1]
365-
365+
assert np.isfinite(m[..., 2]).any() # AIC
366+
assert np.isfinite(m[..., 3]).any() # BIC
366367
assert np.nanmin(r2) > 0.999
367368
assert np.nanmax(rmse) < 1e-6
368369

370+
def test_aic_bic_increase_with_noise_on_same_model():
371+
import sidpy as sid
372+
rng = np.random.default_rng(0)
373+
374+
nx, ny = 2, 2
375+
nkx, nky = 16, 16
376+
x_vec = np.linspace(-10, 10, nkx)
377+
y_vec = np.linspace(-10, 10, nky)
378+
true_p = np.array([1.8, 1.0, -1.0, 2.4, 1.7, 0.05], dtype=float)
379+
380+
clean2d = gaussian_2d((x_vec, y_vec), *true_p)
381+
clean = np.tile(clean2d, (nx, ny, 1, 1))
382+
sigma = 0.02 * np.max(np.abs(clean2d))
383+
noisy = clean + rng.normal(0.0, sigma, size=clean.shape)
384+
385+
def mk(arr, title):
386+
d = sid.Dataset.from_array(arr, title=title)
387+
d.set_dimension(0, sid.Dimension(np.arange(nx), name='X', dimension_type='spatial'))
388+
d.set_dimension(1, sid.Dimension(np.arange(ny), name='Y', dimension_type='spatial'))
389+
d.set_dimension(2, sid.Dimension(x_vec, name='kx', dimension_type='spectral'))
390+
d.set_dimension(3, sid.Dimension(y_vec, name='ky', dimension_type='spectral'))
391+
return d
392+
393+
def run(arr):
394+
fitter = SidpyFitterRefactor(mk(arr, "tmp"), gaussian_2d, gaussian_2d_guess, ind_dims=(2, 3), num_params=6)
395+
fitter.setup_calc()
396+
_, met = fitter.do_fit(use_kmeans=False, loss='linear')
397+
m = np.asarray(met)
398+
return np.nanmedian(m[..., 2]), np.nanmedian(m[..., 3]) # AIC, BIC
399+
400+
aic_clean, bic_clean = run(clean)
401+
aic_noisy, bic_noisy = run(noisy)
402+
403+
assert aic_noisy > aic_clean
404+
assert bic_noisy > bic_clean
405+
369406

370407
def test_metrics_noisy_fit_lower_r2_higher_rmse_than_clean(self):
371408

0 commit comments

Comments
 (0)