Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 64 additions & 36 deletions tests/tendencies/test_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import pytest
from pytest import approx

from tests.utils import filter_kwargs
from waveform_editor.tendencies.base import BaseTendency
from waveform_editor.tendencies.linear import LinearTendency
from waveform_editor.tendencies.periodic.sine_wave import SineWaveTendency


@pytest.mark.parametrize(
Expand All @@ -22,7 +24,7 @@
(None, 0, None, 0, 1, 1, True),
],
)
def test_first_base_tendency(
def test_first_tendency(
start,
duration,
end,
Expand All @@ -33,15 +35,15 @@ def test_first_base_tendency(
):
"""Test validity of the created base tendency when it is the first tendency."""
kwargs = filter_kwargs(user_start=start, user_duration=duration, user_end=end)
base_tendency = BaseTendency(**kwargs)
tendency = LinearTendency(**kwargs)

assert base_tendency.start == approx(expected_start)
assert base_tendency.duration == approx(expected_duration)
assert base_tendency.end == approx(expected_end)
assert tendency.start == approx(expected_start)
assert tendency.duration == approx(expected_duration)
assert tendency.end == approx(expected_end)
if has_error:
assert base_tendency.annotations
assert tendency.annotations
else:
assert not base_tendency.annotations
assert not tendency.annotations


@pytest.mark.parametrize(
Expand All @@ -61,7 +63,7 @@ def test_first_base_tendency(
(None, 0, None, 10, 1, 11, True),
],
)
def test_second_base_tendency(
def test_second_tendency(
start,
duration,
end,
Expand All @@ -71,41 +73,37 @@ def test_second_base_tendency(
has_error,
):
"""Test validity of the created base tendency when it is the second tendency."""
prev_tendency = BaseTendency(user_start=0, user_end=10)
prev_tendency = LinearTendency(user_start=0, user_end=10)
kwargs = filter_kwargs(user_start=start, user_duration=duration, user_end=end)
base_tendency = BaseTendency(**kwargs)
base_tendency.set_previous_tendency(prev_tendency)
prev_tendency.set_next_tendency(base_tendency)
tendency = LinearTendency(**kwargs)
tendency.set_previous_tendency(prev_tendency)
prev_tendency.set_next_tendency(tendency)

assert base_tendency.start == approx(expected_start)
assert base_tendency.duration == approx(expected_duration)
assert base_tendency.end == approx(expected_end)
assert tendency.start == approx(expected_start)
assert tendency.duration == approx(expected_duration)
assert tendency.end == approx(expected_end)
if has_error:
assert base_tendency.annotations
assert tendency.annotations
else:
assert not base_tendency.annotations
assert not tendency.annotations


def test_suggestion():
"""Test if suggestions are provided for miswritten keywords."""
base_tendency = BaseTendency(user_starrt=0, user_duuration=5, user_ennd=10)
assert base_tendency.annotations
assert base_tendency.start == 0
assert base_tendency.duration == 1
assert base_tendency.end == 1
assert any(
"start" in annotation["text"] for annotation in base_tendency.annotations
)
assert any(
"duration" in annotation["text"] for annotation in base_tendency.annotations
)
assert any("end" in annotation["text"] for annotation in base_tendency.annotations)
tendency = LinearTendency(user_starrt=0, user_duuration=5, user_ennd=10)
assert tendency.annotations
assert tendency.start == 0
assert tendency.duration == 1
assert tendency.end == 1
assert any("start" in annotation["text"] for annotation in tendency.annotations)
assert any("duration" in annotation["text"] for annotation in tendency.annotations)
assert any("end" in annotation["text"] for annotation in tendency.annotations)


def test_gap():
"""Test if a gap between 2 tendencies is encountered."""
t1 = BaseTendency(user_start=0, user_duration=5, user_end=5)
t2 = BaseTendency(user_start=15, user_duration=5, user_end=20)
t1 = LinearTendency(user_start=0, user_duration=5, user_end=5)
t2 = LinearTendency(user_start=15, user_duration=5, user_end=20)
t2.set_previous_tendency(t1)
t1.set_next_tendency(t2)
assert not t1.annotations
Expand All @@ -122,8 +120,8 @@ def test_gap():

def test_overlap():
"""Test if an overlap between 2 tendencies is encountered."""
t1 = BaseTendency(user_start=0, user_duration=5, user_end=5)
t2 = BaseTendency(user_start=3, user_duration=5, user_end=8)
t1 = LinearTendency(user_start=0, user_duration=5, user_end=5)
t2 = LinearTendency(user_start=3, user_duration=5, user_end=8)
t2.set_previous_tendency(t1)
t1.set_next_tendency(t2)
assert not t1.annotations
Expand All @@ -139,8 +137,8 @@ def test_overlap():


def test_declarative_assignments():
t1 = BaseTendency(user_duration=10)
t2 = BaseTendency(user_duration=5)
t1 = LinearTendency(user_duration=10)
t2 = LinearTendency(user_duration=5)
t2.set_previous_tendency(t1)

assert t1.end == approx(10)
Expand All @@ -153,3 +151,33 @@ def test_declarative_assignments():
assert t2.end == approx(20)
assert not t1.annotations
assert not t2.annotations


def test_float_error():
"""Don't raise gap annotations when times don't match due to floating point
precision."""
t1 = SineWaveTendency(user_duration=1.7)
t2 = LinearTendency(user_duration=2)
t2.set_previous_tendency(t1)
t1.set_next_tendency(t2)

# t2 starts at 1.7000000000000002 due to floating point precision error
assert t1.end != t2.start
assert np.isclose(t1.end, t2.start)
assert not t1.annotations
assert not t2.annotations


def test_implicit_start_value():
"""Test if start value matches previous end value."""
t1 = SineWaveTendency(user_duration=1.75, user_base=8, user_amplitude=2)
t2 = LinearTendency(user_duration=2, user_to=2)
t2.set_previous_tendency(t1)
t1.set_next_tendency(t2)

assert t1.end_value == 6
assert t2.start_value == 6
_, t1_val = t1.get_value([1.75])
assert t1_val == 6
_, t2_val = t2.get_value([1.75])
assert t2_val == 6
27 changes: 15 additions & 12 deletions waveform_editor/tendencies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(self, **kwargs):
self._handle_error(error)

self._handle_unknown_kwargs(unknown_kwargs)
self.values_changed = True

def _handle_error(self, error):
"""Handle exceptions raised by param assignment and add them as annotations.
Expand Down Expand Up @@ -179,18 +180,20 @@ def set_previous_tendency(self, prev_tendency):
# If the tendency is the first tendency of a repeated tendency, it is linked to
# the last tendency in the repeated tendency. In this case we can ignore this
# error.
if self.prev_tendency.end > self.start and not self.is_first_repeated:
error_msg = (
f"The end of the previous tendency ({self.prev_tendency.end})\nis "
f"later than the start of the current tendency ({self.start}).\n"
)
self.annotations.add(self.line_number, error_msg)
elif self.prev_tendency.end < self.start:
error_msg = (
"Previous tendency ends before the start of the current tendency.\n"
"The values inbetween the tendencies will be linearly interpolated.\n"
)
self.annotations.add(self.line_number, error_msg, is_warning=True)
if not np.isclose(self.prev_tendency.end, self.start):
if self.prev_tendency.end > self.start and not self.is_first_repeated:
error_msg = (
f"The end of the previous tendency ({self.prev_tendency.end})\nis "
f"later than the start of the current tendency ({self.start}).\n"
)
self.annotations.add(self.line_number, error_msg)
elif self.prev_tendency.end < self.start:
error_msg = (
"Previous tendency ends before the start of the current tendency.\n"
"The values inbetween the tendencies will be linearly interpolated."
"\n"
)
self.annotations.add(self.line_number, error_msg, is_warning=True)

self.param.trigger("annotations")

Expand Down
7 changes: 4 additions & 3 deletions waveform_editor/tendencies/repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@ class RepeatTendency(BaseTendency):

def __init__(self, **kwargs):
waveform = kwargs.pop("user_waveform", []) or []
super().__init__(**kwargs)

from waveform_editor.waveform import Waveform

self.waveform = Waveform(waveform=waveform, is_repeated=True)
super().__init__(**kwargs)
if not self.waveform.tendencies:
error_msg = "There are no tendencies in the repeated waveform.\n"
self.annotations.add(self.line_number, error_msg)
Expand Down Expand Up @@ -68,7 +67,7 @@ def get_value(
Tuple containing the time and its tendency values.
"""
if not self.waveform.tendencies:
return np.array([]), np.array([])
return np.array([0]), np.array([0])
length = self.waveform.calc_length()
if time is None:
time, values = self.waveform.get_value()
Expand Down Expand Up @@ -103,6 +102,8 @@ def get_derivative(self, time: np.ndarray) -> np.ndarray:
Returns:
numpy array containing the derivatives
"""
if not self.waveform.tendencies:
return np.array([0])
length = self.waveform.calc_length()
relative_times = (time - self.start) % length
derivatives = self.waveform.get_derivative(relative_times)
Expand Down