Skip to content

Commit 821e35d

Browse files
committed
can force the direction of the step function
1 parent 2952629 commit 821e35d

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

eqsig/functions.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ def generate_smooth_fa_spectrum(smooth_fa_frequencies, fa_frequencies, fa_spectr
604604
return smooth_fa_spectrum
605605

606606

607-
def calc_step_fn_vals_error(values, pow=1):
607+
def calc_step_fn_vals_error(values, pow=1, dir=None):
608608
"""
609609
Calculates the error function generated by fitting a step function
610610
to the values
@@ -618,6 +618,11 @@ def calc_step_fn_vals_error(values, pow=1):
618618
values: array_like
619619
pow: int
620620
The power that the error should be raised to
621+
dir: str
622+
Desired direction of the step function
623+
if 'down', then all upward steps are set to 10x maximum error
624+
if 'up', then all downward steps are set to 10x maximum error
625+
else, no modification to error
621626
622627
Returns
623628
-------
@@ -638,6 +643,12 @@ def calc_step_fn_vals_error(values, pow=1):
638643
err = np.ones_like(values)
639644
err[:-1] = err_post[1:] + err_pre[:-1]
640645
err[-1] = np.sum(np.abs(values - np.mean(values)) ** pow)
646+
if dir == 'down': # if step has to be downward, then increase error for upward steps
647+
max_err = np.max(err)
648+
err = np.where(pre_mean < post_mean, max_err * 10, err)
649+
if dir == 'up':
650+
max_err = np.max(err)
651+
err = np.where(pre_mean > post_mean, max_err * 10, err)
641652
return err
642653

643654

0 commit comments

Comments
 (0)