Skip to content

Commit 30d4347

Browse files
authored
Add normalized root mse, valid_prediction_index
1 parent f4d88d6 commit 30d4347

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

rescomp/utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def relerr(true, pre, order=2, axis=0):
1616
return norm(true - pre, axis=axis) / norm(true, axis=axis)
1717

1818
def accduration(true, pre, tol=0.2, order="inf", axis=0):
19-
warn("The function `rescomp.accduration` is depreciated. Use `rescomp.validpredictiontime` instead")
19+
warn("The function `rescomp.accduration` is depreciated. Use `rescomp.valid_prediction_index` instead")
2020
n = pre.shape[axis]
2121
for i in range(n):
2222
if axis == 0:
@@ -29,6 +29,37 @@ def accduration(true, pre, tol=0.2, order="inf", axis=0):
2929
return i
3030
return n - 1
3131

32+
def nrmse(true, pred, axis=0):
33+
""" Normalized root mean squared error between two mxn arrays.
34+
Parameters
35+
----------
36+
true (ndarray): mxn array of the true values.
37+
pred (ndarray): mxn array of the predicted values
38+
axis (int): Can be 0 or 1. Decide which axis to compute the error
39+
Returns
40+
-------
41+
err (ndarray): If axis=0 returns length m array, if axis=1 returns length n array
42+
"""
43+
sig = np.std(true, axis=axis)
44+
other_axis = (axis + 1 ) % 2 # Sends 0 -> 1 and 1 -> 0
45+
err = np.mean( (true - pred)**2 / sig, axis=other_axis)**.5
46+
return err
47+
48+
def valid_prediction_index(err, tol):
49+
""" First index i where err[i] > tol.
50+
Parameters
51+
----------
52+
err (ndarray): One dimensional array
53+
tol (float): Max allowable error.
54+
Returns
55+
-------
56+
i (int): First index such that err[i] > tol
57+
"""
58+
for i in range(len(err)):
59+
if err[i] > tol:
60+
return i
61+
return i
62+
3263
def system_fit_error(t, U, system, order="inf"):
3364
dt = np.mean(np.diff(t))
3465
ddt = fd.FinDiff(0, dt, acc=6)

0 commit comments

Comments
 (0)