Skip to content

Commit 9eb2743

Browse files
committed
remove reliance on NaNs for interpolation
1 parent 60b2e2e commit 9eb2743

File tree

1 file changed

+25
-57
lines changed

1 file changed

+25
-57
lines changed

waveform_editor/waveform.py

Lines changed: 25 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -86,74 +86,42 @@ def _evaluate_tendencies(self, time, eval_derivatives=False):
8686
Returns:
8787
numpy array containing the computed values.
8888
"""
89-
values = np.full_like(time, np.nan, dtype=float)
89+
values = np.zeros_like(time, dtype=float)
9090

91-
for tendency in self.tendencies:
91+
for i, tendency in enumerate(self.tendencies):
9292
mask = (time >= tendency.start) & (time <= tendency.end)
9393
if np.any(mask):
9494
if eval_derivatives:
9595
values[mask] = tendency.get_derivative(time[mask])
9696
else:
9797
_, values[mask] = tendency.get_value(time[mask])
9898

99-
# If there still remain nans in the values, this means that there are gaps
100-
# between the tendencies. In this case we linearly interpolate between the gap
101-
# values
102-
for idx, t in enumerate(time):
103-
if np.isnan(values[idx]):
104-
# The derivatives of interpolated gaps are not calculated
105-
if eval_derivatives:
106-
values[idx] = 0
107-
else:
108-
values[idx] = self._interpolate_gap(t)
99+
# If there still remain nans in the values, this means that there are gaps
100+
# between the tendencies. In this case we linearly interpolate between the
101+
# gap values
102+
if i and tendency.prev_tendency.end < tendency.start:
103+
mask = (time < tendency.start) & (time > tendency.prev_tendency.end)
104+
if np.any(mask):
105+
if eval_derivatives:
106+
values[mask] = 0
107+
else:
108+
values[mask] = np.interp(
109+
time[mask],
110+
[tendency.prev_tendency.end, tendency.start],
111+
[tendency.prev_tendency.end_value, tendency.start_value],
112+
)
113+
# Handle extrapolation
114+
if eval_derivatives:
115+
values[time < self.tendencies[0].start] = 0
116+
values[time > self.tendencies[-1].end] = 0
117+
else:
118+
first_tendency = self.tendencies[0]
119+
values[time < first_tendency.start] = first_tendency.start_value
109120

121+
last_tendency = self.tendencies[-1]
122+
values[time > last_tendency.end] = last_tendency.end_value
110123
return values
111124

112-
def _interpolate_gap(self, t):
113-
"""Interpolates the value for a given time t based on the nearest tendencies.
114-
Also extrapolates the values if the time requested falls before the first, or
115-
after the last tendency.
116-
117-
Args:
118-
t: The time for which the value needs to be interpolated.
119-
120-
Returns:
121-
The interpolated value.
122-
"""
123-
# Find nearest tendencies before and after time t
124-
prev_tendency = max(
125-
(tend for tend in self.tendencies if tend.end <= t),
126-
default=None,
127-
key=lambda tend: tend.end,
128-
)
129-
next_tendency = min(
130-
(tend for tend in self.tendencies if tend.start >= t),
131-
default=None,
132-
key=lambda tend: tend.start,
133-
)
134-
135-
if prev_tendency and next_tendency:
136-
val_end = prev_tendency.end_value
137-
val_start = next_tendency.start_value
138-
139-
return np.interp(
140-
t, [prev_tendency.end, next_tendency.start], [val_end, val_start]
141-
)
142-
143-
# Handle extrapolation if t is before the first or after the last tendency
144-
if prev_tendency is None:
145-
next_tendency = self.tendencies[0]
146-
val_start = next_tendency.start_value
147-
return val_start
148-
149-
if next_tendency is None:
150-
prev_tendency = self.tendencies[-1]
151-
val_end = prev_tendency.end_value
152-
return val_end
153-
154-
# If no valid interpolation or extrapolation can be performed, return 0
155-
return 0.0
156-
157125
def calc_length(self):
158126
"""Returns the length of the waveform."""
159127
return self.tendencies[-1].end - self.tendencies[0].start

0 commit comments

Comments
 (0)