Skip to content

Commit ace5d45

Browse files
committed
add tests for waveform, and add interpolation
1 parent 185d13b commit ace5d45

File tree

3 files changed

+262
-5
lines changed

3 files changed

+262
-5
lines changed

tests/test_waveform.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
import numpy as np
2+
import pytest
3+
4+
from waveform_editor.tendencies.constant import ConstantTendency
5+
from waveform_editor.tendencies.linear import LinearTendency
6+
from waveform_editor.tendencies.periodic.sine_wave import SineWaveTendency
7+
from waveform_editor.tendencies.smooth import SmoothTendency
8+
from waveform_editor.waveform import Waveform
9+
10+
11+
def test_empty():
12+
waveform = Waveform()
13+
assert waveform.tendencies == []
14+
assert waveform.annotations == []
15+
16+
17+
@pytest.fixture
18+
def waveform_list():
19+
return [
20+
{"type": "linear", "from": 0, "to": 8, "duration": 5, "line_number": 1},
21+
{
22+
"type": "sine-wave",
23+
"base": 8,
24+
"amplitude": 2,
25+
"frequency": 1,
26+
"duration": 4,
27+
"line_number": 2,
28+
},
29+
{"type": "constant", "value": 8, "duration": 3, "line_number": 3},
30+
{"type": "smooth", "from": 8, "to": 0, "duration": 2, "line_number": 4},
31+
]
32+
33+
34+
@pytest.fixture
35+
def waveform(waveform_list):
36+
return Waveform(waveform=waveform_list)
37+
38+
39+
def test_annotations(waveform_list):
40+
"""Test if annotations of tendencies are passed to waveform's annotations."""
41+
waveform = Waveform(waveform=waveform_list)
42+
assert not waveform.annotations
43+
44+
waveform_list[0]["type"] = "sine-wav"
45+
waveform = Waveform(waveform=waveform_list)
46+
assert waveform.annotations
47+
48+
49+
def test_tendencies(waveform):
50+
"""Test if tendencies are of correct type."""
51+
assert isinstance(waveform.tendencies[0], LinearTendency)
52+
assert isinstance(waveform.tendencies[1], SineWaveTendency)
53+
assert isinstance(waveform.tendencies[2], ConstantTendency)
54+
assert isinstance(waveform.tendencies[3], SmoothTendency)
55+
56+
57+
def test_get_value(waveform):
58+
"""Test if get_value returns the correct values."""
59+
times = np.linspace(0, 14, 15)
60+
_, values = waveform.get_value(times)
61+
expected = [0, 1.6, 3.2, 4.8, 6.4, 8, 8, 8, 8, 8, 8, 8, 8, 4, 0]
62+
assert np.allclose(values, expected)
63+
64+
65+
def test_get_derivative(waveform):
66+
"""Test if get_derivative returns the correct values."""
67+
times = np.linspace(0, 14, 15)
68+
derivatives = waveform.get_derivative(times)
69+
fpi = 4 * np.pi
70+
expected = [1.6, 1.6, 1.6, 1.6, 1.6, fpi, fpi, fpi, fpi, 0, 0, 0, 0, -6, 0]
71+
assert np.allclose(derivatives, expected)
72+
73+
74+
def test_length(waveform):
75+
"""Test if calc_length returns the correct value."""
76+
assert waveform.calc_length() == 14
77+
78+
79+
def test_gap():
80+
"""Test if gap between tendency is interpolated."""
81+
gap_waveform = [
82+
{"type": "constant", "value": 3, "start": 0, "end": 2, "line_number": 1},
83+
{"type": "constant", "value": 5, "start": 4, "end": 5, "line_number": 2},
84+
]
85+
waveform = Waveform(waveform=gap_waveform)
86+
assert waveform.annotations
87+
times, values = waveform.get_value()
88+
assert np.allclose(times, [0, 2, 4, 5])
89+
assert np.allclose(values, [3, 3, 5, 5])
90+
91+
expected = [3, 3, 3, 3, 3, 3.5, 4, 4.5, 5, 5, 5]
92+
_, values = waveform.get_value(np.linspace(0, 5, 11))
93+
assert np.allclose(values, expected)
94+
95+
96+
def test_gap_derivative():
97+
"""Test if derivative of gap between tendency is set to zero."""
98+
gap_waveform = [
99+
{"type": "constant", "value": 3, "start": 0, "end": 2, "line_number": 1},
100+
{"type": "constant", "value": 5, "start": 4, "end": 5, "line_number": 2},
101+
]
102+
waveform = Waveform(waveform=gap_waveform)
103+
assert waveform.annotations
104+
105+
values = waveform.get_derivative(np.linspace(0, 5, 11))
106+
assert np.allclose(values, np.zeros(11))
107+
108+
109+
def test_get_value_outside(waveform):
110+
"""Test if values outside of range are clipped."""
111+
gap_waveform = [
112+
{"type": "constant", "value": 3, "start": 0, "end": 2, "line_number": 1},
113+
{"type": "constant", "value": 5, "start": 4, "end": 5, "line_number": 2},
114+
]
115+
gap_waveform = Waveform(waveform=gap_waveform)
116+
# test requesting values outside of time range
117+
_, gap_values = gap_waveform.get_value(np.linspace(-1, 0, 4))
118+
_, values = waveform.get_value(np.linspace(-5, 0, 6))
119+
assert np.allclose(gap_values, [3, 3, 3, 3])
120+
assert np.allclose(values, np.zeros(6))
121+
122+
# test requesting values outside of time range
123+
_, gap_values = gap_waveform.get_value(np.linspace(5, 6, 4))
124+
_, values = waveform.get_value(np.linspace(14, 18, 5))
125+
assert np.allclose(gap_values, [5, 5, 5, 5])
126+
assert np.allclose(values, np.zeros(5))
127+
128+
129+
def test_get_derivative_outside(waveform):
130+
"""Test if derivatives outside of range are set to zero."""
131+
gap_waveform = [
132+
{"type": "linear", "from": 3, "to": 7, "start": 0, "end": 2, "line_number": 1},
133+
{
134+
"type": "linear",
135+
"from": 6,
136+
"to": 3,
137+
"start": 4,
138+
"end": 5,
139+
"line_number": 2,
140+
},
141+
]
142+
gap_waveform = Waveform(waveform=gap_waveform)
143+
# test requesting values outside of time range
144+
gap_derivatives = gap_waveform.get_derivative(np.linspace(-1, 0, 4))
145+
derivatives = waveform.get_derivative(np.linspace(-5, 0, 6))
146+
assert np.allclose(gap_derivatives, [0, 0, 0, 2])
147+
assert np.allclose(derivatives, [0, 0, 0, 0, 0, 1.6])
148+
149+
# test requesting values outside of time range
150+
gap_derivatives = gap_waveform.get_derivative(np.linspace(5, 6, 4))
151+
derivatives = waveform.get_derivative(np.linspace(14, 18, 5))
152+
assert np.allclose(gap_derivatives, [-3, 0, 0, 0])
153+
assert np.allclose(derivatives, np.zeros(5))
154+
155+
156+
def test_overlap():
157+
"""Test values if tendencies overlap."""
158+
overlap_waveform = [
159+
{"type": "constant", "value": 3, "start": 0, "end": 2, "line_number": 1},
160+
{"type": "constant", "value": 5, "start": 1, "end": 3, "line_number": 2},
161+
]
162+
waveform = Waveform(waveform=overlap_waveform)
163+
assert waveform.annotations
164+
times, values = waveform.get_value()
165+
assert np.allclose(times, [0, 2, 1, 3])
166+
assert np.allclose(values, [3, 3, 5, 5])
167+
168+
# Later tendencies take precedence
169+
expected = [3, 3, 5, 5, 5, 5, 5]
170+
_, values = waveform.get_value(np.linspace(0, 3, 7))
171+
assert np.allclose(values, expected)
172+
173+
174+
def test_overlap_derivatives():
175+
"""Test derivatives if tendencies overlap."""
176+
overlap_waveform = [
177+
{"type": "linear", "from": 3, "to": 7, "start": 0, "end": 2, "line_number": 1},
178+
{
179+
"type": "linear",
180+
"from": 6,
181+
"to": 3,
182+
"start": 1,
183+
"end": 3,
184+
"line_number": 2,
185+
},
186+
]
187+
waveform = Waveform(waveform=overlap_waveform)
188+
assert waveform.annotations
189+
190+
# Later tendencies take precedence
191+
expected = [2, 2, -1.5, -1.5, -1.5, -1.5, -1.5]
192+
values = waveform.get_derivative(np.linspace(0, 3, 7))
193+
assert np.allclose(values, expected)

waveform_editor/tendencies/repeat.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,20 @@ def __init__(self, **kwargs):
3333
self.waveform.tendencies[0].set_previous_tendency(self.waveform.tendencies[-1])
3434
self.waveform.tendencies[-1].set_next_tendency(self.waveform.tendencies[0])
3535

36-
_, self.start_value = self.get_value(self.start)
37-
self.start_derivative = self.get_derivative(self.start)
38-
_, self.end_value = self.get_value(self.end)
39-
self.end_derivative = self.get_derivative(self.end)
36+
self._set_bounds()
4037
self.annotations.add_annotations(self.waveform.annotations)
4138

39+
def _set_bounds(self):
40+
"""Sets the start and end values, as well as derivatives"""
41+
_, start_values = self.get_value(np.array([self.start]))
42+
self.start_value = start_values[0]
43+
start_derivatives = self.get_derivative(np.array([self.start]))
44+
self.start_derivative = start_derivatives[0]
45+
_, end_values = self.get_value(np.array([self.end]))
46+
self.end_value = end_values[0]
47+
end_derivatives = self.get_derivative(np.array([self.end]))
48+
self.end_derivative = end_derivatives[0]
49+
4250
def get_value(
4351
self, time: Optional[np.ndarray] = None
4452
) -> tuple[np.ndarray, np.ndarray]:

waveform_editor/waveform.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _evaluate_tendencies(self, time, eval_derivatives=False):
8787
Returns:
8888
numpy array containing the computed values.
8989
"""
90-
values = np.zeros_like(time, dtype=float)
90+
values = np.full_like(time, np.nan, dtype=float)
9191

9292
for tendency in self.tendencies:
9393
mask = (time >= tendency.start) & (time <= tendency.end)
@@ -97,8 +97,64 @@ def _evaluate_tendencies(self, time, eval_derivatives=False):
9797
else:
9898
_, values[mask] = tendency.get_value(time[mask])
9999

100+
# If there still remain nans in the values, this means that there are gaps
101+
# between the tendencies. In this case we linearly interpolate between the gap
102+
# values
103+
for idx, t in enumerate(time):
104+
if np.isnan(values[idx]):
105+
# The derivatives of interpolated gaps are not calculated
106+
if eval_derivatives:
107+
values[idx] = 0
108+
else:
109+
values[idx] = self._interpolate_gap(t)
110+
100111
return values
101112

113+
def _interpolate_gap(self, t):
114+
"""Interpolates the value for a given time t based on the nearest tendencies.
115+
Also extrapolates the values if the time requested falls before the first, or
116+
after the last tendency.
117+
118+
Args:
119+
t: The time for which the value needs to be interpolated.
120+
121+
Returns:
122+
The interpolated value.
123+
"""
124+
# Find nearest tendencies before and after time t
125+
prev_tendency = max(
126+
(tend for tend in self.tendencies if tend.end <= t),
127+
default=None,
128+
key=lambda tend: tend.end,
129+
)
130+
next_tendency = min(
131+
(tend for tend in self.tendencies if tend.start >= t),
132+
default=None,
133+
key=lambda tend: tend.start,
134+
)
135+
136+
if prev_tendency and next_tendency:
137+
val_end = prev_tendency.end_value
138+
val_start = next_tendency.start_value
139+
140+
return np.interp(
141+
t, [prev_tendency.end, next_tendency.start], [val_end, val_start]
142+
)
143+
144+
# Handle extrapolation if t is before the first or after the last tendency
145+
if prev_tendency is None:
146+
next_tendency = self.tendencies[0]
147+
val_start = next_tendency.start_value
148+
return val_start
149+
150+
if next_tendency is None:
151+
prev_tendency = self.tendencies[-1]
152+
val_end = prev_tendency.end_value
153+
return val_end
154+
155+
# If no valid interpolation or extrapolation can be performed, return 0
156+
return 0.0
157+
102158
def calc_length(self):
103159
"""Returns the length of the waveform."""
104160
return self.tendencies[-1].end - self.tendencies[0].start

0 commit comments

Comments
 (0)