Skip to content

Commit f0eabeb

Browse files
authored
Merge pull request #27 from iterorganization/bugfix/fix-implicit-linking
Fix bug where start/end values are not updated correctly after initialization
2 parents b28326c + f71b58d commit f0eabeb

File tree

3 files changed

+83
-51
lines changed

3 files changed

+83
-51
lines changed

tests/tendencies/test_base.py

Lines changed: 64 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import numpy as np
12
import pytest
23
from pytest import approx
34

45
from tests.utils import filter_kwargs
5-
from waveform_editor.tendencies.base import BaseTendency
6+
from waveform_editor.tendencies.linear import LinearTendency
7+
from waveform_editor.tendencies.periodic.sine_wave import SineWaveTendency
68

79

810
@pytest.mark.parametrize(
@@ -22,7 +24,7 @@
2224
(None, 0, None, 0, 1, 1, True),
2325
],
2426
)
25-
def test_first_base_tendency(
27+
def test_first_tendency(
2628
start,
2729
duration,
2830
end,
@@ -33,15 +35,15 @@ def test_first_base_tendency(
3335
):
3436
"""Test validity of the created base tendency when it is the first tendency."""
3537
kwargs = filter_kwargs(user_start=start, user_duration=duration, user_end=end)
36-
base_tendency = BaseTendency(**kwargs)
38+
tendency = LinearTendency(**kwargs)
3739

38-
assert base_tendency.start == approx(expected_start)
39-
assert base_tendency.duration == approx(expected_duration)
40-
assert base_tendency.end == approx(expected_end)
40+
assert tendency.start == approx(expected_start)
41+
assert tendency.duration == approx(expected_duration)
42+
assert tendency.end == approx(expected_end)
4143
if has_error:
42-
assert base_tendency.annotations
44+
assert tendency.annotations
4345
else:
44-
assert not base_tendency.annotations
46+
assert not tendency.annotations
4547

4648

4749
@pytest.mark.parametrize(
@@ -61,7 +63,7 @@ def test_first_base_tendency(
6163
(None, 0, None, 10, 1, 11, True),
6264
],
6365
)
64-
def test_second_base_tendency(
66+
def test_second_tendency(
6567
start,
6668
duration,
6769
end,
@@ -71,41 +73,37 @@ def test_second_base_tendency(
7173
has_error,
7274
):
7375
"""Test validity of the created base tendency when it is the second tendency."""
74-
prev_tendency = BaseTendency(user_start=0, user_end=10)
76+
prev_tendency = LinearTendency(user_start=0, user_end=10)
7577
kwargs = filter_kwargs(user_start=start, user_duration=duration, user_end=end)
76-
base_tendency = BaseTendency(**kwargs)
77-
base_tendency.set_previous_tendency(prev_tendency)
78-
prev_tendency.set_next_tendency(base_tendency)
78+
tendency = LinearTendency(**kwargs)
79+
tendency.set_previous_tendency(prev_tendency)
80+
prev_tendency.set_next_tendency(tendency)
7981

80-
assert base_tendency.start == approx(expected_start)
81-
assert base_tendency.duration == approx(expected_duration)
82-
assert base_tendency.end == approx(expected_end)
82+
assert tendency.start == approx(expected_start)
83+
assert tendency.duration == approx(expected_duration)
84+
assert tendency.end == approx(expected_end)
8385
if has_error:
84-
assert base_tendency.annotations
86+
assert tendency.annotations
8587
else:
86-
assert not base_tendency.annotations
88+
assert not tendency.annotations
8789

8890

8991
def test_suggestion():
9092
"""Test if suggestions are provided for miswritten keywords."""
91-
base_tendency = BaseTendency(user_starrt=0, user_duuration=5, user_ennd=10)
92-
assert base_tendency.annotations
93-
assert base_tendency.start == 0
94-
assert base_tendency.duration == 1
95-
assert base_tendency.end == 1
96-
assert any(
97-
"start" in annotation["text"] for annotation in base_tendency.annotations
98-
)
99-
assert any(
100-
"duration" in annotation["text"] for annotation in base_tendency.annotations
101-
)
102-
assert any("end" in annotation["text"] for annotation in base_tendency.annotations)
93+
tendency = LinearTendency(user_starrt=0, user_duuration=5, user_ennd=10)
94+
assert tendency.annotations
95+
assert tendency.start == 0
96+
assert tendency.duration == 1
97+
assert tendency.end == 1
98+
assert any("start" in annotation["text"] for annotation in tendency.annotations)
99+
assert any("duration" in annotation["text"] for annotation in tendency.annotations)
100+
assert any("end" in annotation["text"] for annotation in tendency.annotations)
103101

104102

105103
def test_gap():
106104
"""Test if a gap between 2 tendencies is encountered."""
107-
t1 = BaseTendency(user_start=0, user_duration=5, user_end=5)
108-
t2 = BaseTendency(user_start=15, user_duration=5, user_end=20)
105+
t1 = LinearTendency(user_start=0, user_duration=5, user_end=5)
106+
t2 = LinearTendency(user_start=15, user_duration=5, user_end=20)
109107
t2.set_previous_tendency(t1)
110108
t1.set_next_tendency(t2)
111109
assert not t1.annotations
@@ -122,8 +120,8 @@ def test_gap():
122120

123121
def test_overlap():
124122
"""Test if an overlap between 2 tendencies is encountered."""
125-
t1 = BaseTendency(user_start=0, user_duration=5, user_end=5)
126-
t2 = BaseTendency(user_start=3, user_duration=5, user_end=8)
123+
t1 = LinearTendency(user_start=0, user_duration=5, user_end=5)
124+
t2 = LinearTendency(user_start=3, user_duration=5, user_end=8)
127125
t2.set_previous_tendency(t1)
128126
t1.set_next_tendency(t2)
129127
assert not t1.annotations
@@ -139,8 +137,8 @@ def test_overlap():
139137

140138

141139
def test_declarative_assignments():
142-
t1 = BaseTendency(user_duration=10)
143-
t2 = BaseTendency(user_duration=5)
140+
t1 = LinearTendency(user_duration=10)
141+
t2 = LinearTendency(user_duration=5)
144142
t2.set_previous_tendency(t1)
145143

146144
assert t1.end == approx(10)
@@ -153,3 +151,33 @@ def test_declarative_assignments():
153151
assert t2.end == approx(20)
154152
assert not t1.annotations
155153
assert not t2.annotations
154+
155+
156+
def test_float_error():
157+
"""Don't raise gap annotations when times don't match due to floating point
158+
precision."""
159+
t1 = SineWaveTendency(user_duration=1.7)
160+
t2 = LinearTendency(user_duration=2)
161+
t2.set_previous_tendency(t1)
162+
t1.set_next_tendency(t2)
163+
164+
# t2 starts at 1.7000000000000002 due to floating point precision error
165+
assert t1.end != t2.start
166+
assert np.isclose(t1.end, t2.start)
167+
assert not t1.annotations
168+
assert not t2.annotations
169+
170+
171+
def test_implicit_start_value():
172+
"""Test if start value matches previous end value."""
173+
t1 = SineWaveTendency(user_duration=1.75, user_base=8, user_amplitude=2)
174+
t2 = LinearTendency(user_duration=2, user_to=2)
175+
t2.set_previous_tendency(t1)
176+
t1.set_next_tendency(t2)
177+
178+
assert t1.end_value == 6
179+
assert t2.start_value == 6
180+
_, t1_val = t1.get_value([1.75])
181+
assert t1_val == 6
182+
_, t2_val = t2.get_value([1.75])
183+
assert t2_val == 6

waveform_editor/tendencies/base.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(self, **kwargs):
115115
self._handle_error(error)
116116

117117
self._handle_unknown_kwargs(unknown_kwargs)
118+
self.values_changed = True
118119

119120
def _handle_error(self, error):
120121
"""Handle exceptions raised by param assignment and add them as annotations.
@@ -179,18 +180,20 @@ def set_previous_tendency(self, prev_tendency):
179180
# If the tendency is the first tendency of a repeated tendency, it is linked to
180181
# the last tendency in the repeated tendency. In this case we can ignore this
181182
# error.
182-
if self.prev_tendency.end > self.start and not self.is_first_repeated:
183-
error_msg = (
184-
f"The end of the previous tendency ({self.prev_tendency.end})\nis "
185-
f"later than the start of the current tendency ({self.start}).\n"
186-
)
187-
self.annotations.add(self.line_number, error_msg)
188-
elif self.prev_tendency.end < self.start:
189-
error_msg = (
190-
"Previous tendency ends before the start of the current tendency.\n"
191-
"The values inbetween the tendencies will be linearly interpolated.\n"
192-
)
193-
self.annotations.add(self.line_number, error_msg, is_warning=True)
183+
if not np.isclose(self.prev_tendency.end, self.start):
184+
if self.prev_tendency.end > self.start and not self.is_first_repeated:
185+
error_msg = (
186+
f"The end of the previous tendency ({self.prev_tendency.end})\nis "
187+
f"later than the start of the current tendency ({self.start}).\n"
188+
)
189+
self.annotations.add(self.line_number, error_msg)
190+
elif self.prev_tendency.end < self.start:
191+
error_msg = (
192+
"Previous tendency ends before the start of the current tendency.\n"
193+
"The values inbetween the tendencies will be linearly interpolated."
194+
"\n"
195+
)
196+
self.annotations.add(self.line_number, error_msg, is_warning=True)
194197

195198
self.param.trigger("annotations")
196199

waveform_editor/tendencies/repeat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@ class RepeatTendency(BaseTendency):
1212

1313
def __init__(self, **kwargs):
1414
waveform = kwargs.pop("user_waveform", []) or []
15-
super().__init__(**kwargs)
16-
1715
from waveform_editor.waveform import Waveform
1816

1917
self.waveform = Waveform(waveform=waveform, is_repeated=True)
18+
super().__init__(**kwargs)
2019
if not self.waveform.tendencies:
2120
error_msg = "There are no tendencies in the repeated waveform.\n"
2221
self.annotations.add(self.line_number, error_msg)
@@ -68,7 +67,7 @@ def get_value(
6867
Tuple containing the time and its tendency values.
6968
"""
7069
if not self.waveform.tendencies:
71-
return np.array([]), np.array([])
70+
return np.array([0]), np.array([0])
7271
length = self.waveform.calc_length()
7372
if time is None:
7473
time, values = self.waveform.get_value()
@@ -103,6 +102,8 @@ def get_derivative(self, time: np.ndarray) -> np.ndarray:
103102
Returns:
104103
numpy array containing the derivatives
105104
"""
105+
if not self.waveform.tendencies:
106+
return np.array([0])
106107
length = self.waveform.calc_length()
107108
relative_times = (time - self.start) % length
108109
derivatives = self.waveform.get_derivative(relative_times)

0 commit comments

Comments
 (0)