Skip to content

Commit 0bb1ce8

Browse files
committed
refactor how derivatives and values are handled
1 parent 81bd7aa commit 0bb1ce8

File tree

6 files changed

+64
-90
lines changed

6 files changed

+64
-90
lines changed

waveform_editor/tendencies/base.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import abstractmethod
2+
from typing import Optional
23

34
import numpy as np
45
import param
@@ -55,6 +56,9 @@ class BaseTendency(param.Parameterized):
5556
start_value = param.Number(doc="Value at self.start")
5657
end_value = param.Number(doc="Value at self.end")
5758

59+
start_derivative = param.Number(doc="Derivative at self.start")
60+
end_derivative = param.Number(doc="Derivative at self.end")
61+
5862
time_error = param.ClassSelector(
5963
class_=Exception,
6064
default=None,
@@ -103,41 +107,24 @@ def set_next_tendency(self, next_tendency):
103107

104108
@depends("values_changed", watch=True)
105109
def _calc_start_end_values(self):
106-
self.start_value = self.get_start_value()
107-
self.end_value = self.get_end_value()
108-
109-
@abstractmethod
110-
def get_start_value(self) -> float:
111-
"""Returns the value of the tendency at the start."""
112-
return 0.0
110+
_, self.start_value = self.get_value(np.array([self.start]))
111+
_, self.start_derivative = self.get_derivative(np.array([self.start]))
113112

114-
@abstractmethod
115-
def get_end_value(self) -> float:
116-
"""Returns the value of the tendency at the end."""
117-
return 0.0
118-
119-
@abstractmethod
120-
def get_derivative_start(self) -> float:
121-
"""Returns the derivative of the tendency at the start."""
122-
return 0.0
113+
_, self.end_value = self.get_value(np.array([self.end]))
114+
_, self.end_derivative = self.get_derivative(np.array([self.end]))
123115

124116
@abstractmethod
125-
def get_derivative_end(self) -> float:
126-
"""Returns the derivative of the tendency at the end."""
127-
return 0.0
117+
def get_value(
118+
self, time: Optional[np.ndarray] = None
119+
) -> tuple[np.ndarray, np.ndarray]:
120+
"""Get the values on the provided time array."""
121+
pass
128122

129123
@abstractmethod
130-
def generate(self, time) -> tuple[np.ndarray, np.ndarray]:
131-
"""Generate time and values based on the tendency. If no time array is provided,
132-
a linearly spaced time array will be generated from the start to the end of the
133-
tendency.
134-
135-
Args:
136-
time: The time array on which to generate points.
137-
138-
Returns:
139-
Tuple containing the time and its tendency values.
140-
"""
124+
def get_derivative(
125+
self, time: Optional[np.ndarray] = None
126+
) -> tuple[np.ndarray, np.ndarray]:
127+
"""Get the derivative values on the provided time array."""
141128
pass
142129

143130
@depends(

waveform_editor/tendencies/constant.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import numpy as np
24
import param
35
from param import depends
@@ -19,8 +21,10 @@ def __init__(self, **kwargs):
1921
self.value = 0.0
2022
super().__init__(**kwargs)
2123

22-
def generate(self, time=None):
23-
"""Generate time and values based on the tendency. If no time array is provided,
24+
def get_value(
25+
self, time: Optional[np.ndarray] = None
26+
) -> tuple[np.ndarray, np.ndarray]:
27+
"""Get the values onf the provided time array. If no time array is provided,
2428
a constant line containing the start and end points will be generated.
2529
2630
Args:
@@ -34,21 +38,22 @@ def generate(self, time=None):
3438
values = self.value * np.ones(len(time))
3539
return time, values
3640

37-
def get_start_value(self) -> float:
38-
"""Returns the value of the tendency at the start."""
39-
return self.value
40-
41-
def get_end_value(self) -> float:
42-
"""Returns the value of the tendency at the end."""
43-
return self.value
41+
def get_derivative(
42+
self, time: Optional[np.ndarray] = None
43+
) -> tuple[np.ndarray, np.ndarray]:
44+
"""Get the derivative values on the provided time array. If no time array is
45+
provided, a constant line containing the start and end points will be returned.
4446
45-
def get_derivative_start(self) -> float:
46-
"""Returns the derivative of the tendency at the start."""
47-
return 0
47+
Args:
48+
time: The time array on which to generate points.
4849
49-
def get_derivative_end(self) -> float:
50-
"""Returns the derivative of the tendency at the end."""
51-
return 0
50+
Returns:
51+
Tuple containing the time and its tendency values.
52+
"""
53+
if time is None:
54+
time = np.array([self.start, self.end])
55+
derivatives = np.zeros(len(time))
56+
return time, derivatives
5257

5358
@depends(
5459
"prev_tendency.end_value",

waveform_editor/tendencies/linear.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import numpy as np
24
import param
35
from param import depends
@@ -33,8 +35,10 @@ def __init__(self, **kwargs):
3335
self.rate = 0.0
3436
super().__init__(**kwargs)
3537

36-
def generate(self, time=None):
37-
"""Generate time and values based on the tendency. If no time array is provided,
38+
def get_value(
39+
self, time: Optional[np.ndarray] = None
40+
) -> tuple[np.ndarray, np.ndarray]:
41+
"""Get the values onf the provided time array. If no time array is provided,
3842
a line containing the start and end points will be generated.
3943
4044
Args:
@@ -45,28 +49,26 @@ def generate(self, time=None):
4549
"""
4650
if time is None:
4751
time = np.array([self.start, self.end])
48-
time = np.array(time)
49-
5052
normalized_time = (time - self.start) / (self.end - self.start)
51-
5253
values = self.from_ + (self.to - self.from_) * normalized_time
5354
return time, values
5455

55-
def get_start_value(self) -> float:
56-
"""Returns the value of the tendency at the start."""
57-
return self.from_
56+
def get_derivative(
57+
self, time: Optional[np.ndarray] = None
58+
) -> tuple[np.ndarray, np.ndarray]:
59+
"""Get the derivative values on the provided time array. If no time array is
60+
a constant line containing the start and end points will be generated.
5861
59-
def get_end_value(self) -> float:
60-
"""Returns the value of the tendency at the end."""
61-
return self.to
62-
63-
def get_derivative_start(self) -> float:
64-
"""Returns the derivative of the tendency at the start."""
65-
return self.rate
62+
Args:
63+
time: The time array on which to generate points.
6664
67-
def get_derivative_end(self) -> float:
68-
"""Returns the derivative of the tendency at the end."""
69-
return self.rate
65+
Returns:
66+
Tuple containing the time and its tendency values.
67+
"""
68+
if time is None:
69+
time = np.array([self.start, self.end])
70+
derivatives = self.rate * np.ones(len(time))
71+
return time, derivatives
7072

7173
# Workaround: param doesn't like a @depends on both prev and next tendency
7274
_trigger = param.Event()
@@ -104,7 +106,7 @@ def _calc_values(self):
104106
if self.prev_tendency is None:
105107
inputs[0] = 0
106108
else:
107-
inputs[0] = self.prev_tendency.get_end_value()
109+
inputs[0] = self.prev_tendency.get_value(self.start)
108110
num_inputs += 1
109111
start_value_set = False
110112
else:
@@ -113,7 +115,7 @@ def _calc_values(self):
113115
if num_inputs < 2 and inputs[2] is None:
114116
# To value is not provided, set to from_ or next start value
115117
if self.next_tendency is not None and self.next_tendency.start_value_set:
116-
inputs[2] = self.next_tendency.get_start_value()
118+
inputs[2] = self.next_tendency.get_value(self.end)
117119
else:
118120
inputs[2] = inputs[0]
119121
num_inputs += 1

waveform_editor/waveform.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, waveform):
3333
self._process_waveform(waveform)
3434
self.calc_length()
3535

36-
def generate(self, time=None):
36+
def get_value(self, time=None):
3737
"""Generate time and values based on the tendency. If no time array is provided,
3838
a constant line containing the start and end points will be generated.
3939
@@ -47,7 +47,7 @@ def generate(self, time=None):
4747
times = []
4848
values = []
4949
for tendency in self.tendencies:
50-
time, value = tendency.generate()
50+
time, value = tendency.get_value()
5151
times.extend(time)
5252
values.extend(value)
5353
times = np.array(times)
@@ -60,30 +60,12 @@ def generate(self, time=None):
6060

6161
if np.any(mask):
6262
relevant_times = times[mask]
63-
_, generated_values = tendency.generate(relevant_times)
63+
_, generated_values = tendency.get_value(relevant_times)
6464

6565
values[mask] = generated_values
6666

6767
return times, values
6868

69-
def get_start_value(self) -> float:
70-
"""Returns the value of the tendency at the start."""
71-
return self.generate(self.start)
72-
73-
def get_end_value(self) -> float:
74-
"""Returns the value of the tendency at the end."""
75-
return self.generate(self.end)
76-
77-
def get_derivative_start(self) -> float:
78-
"""Returns the derivative of the tendency at the start."""
79-
# TODO:
80-
return 0
81-
82-
def get_derivative_end(self) -> float:
83-
"""Returns the derivative of the tendency at the end."""
84-
# TODO:
85-
return 0
86-
8769
def calc_length(self):
8870
"""Returns the length of the waveform."""
8971
return self.tendencies[-1].end - self.tendencies[0].start

waveform_editor/waveform_editor_gui.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
value="""\
1515
waveform:
1616
- {type: linear, from: 0, to: 8, duration: 5}
17-
- {type: sine-wave, base: 8, amplitude: 2, frequency: 1, duration: 4}
1817
- {type: constant, value: 8, duration: 3}
19-
- {type: smooth, from: 8, to: 0, duration: 2}
2018
""",
2119
width=600,
2220
height=1200,
@@ -38,7 +36,7 @@ def update_plot(value):
3836
yaml_parser.tendencies = []
3937
yaml_parser.parse_waveforms_from_string(value)
4038

41-
return yaml_parser.plot_tendencies()
39+
return yaml_parser.plot_tendencies(True)
4240

4341

4442
hv_dynamic_map = hv.DynamicMap(pn.bind(update_plot, value=code_editor.param.value))

waveform_editor/yaml_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def plot_tendencies(self, plot_time_points=False):
3636
Returns:
3737
A Holoviews Overlay object.
3838
"""
39-
times, values = self.waveform.generate()
39+
times, values = self.waveform.get_value()
4040

4141
overlay = hv.Overlay()
4242

0 commit comments

Comments
 (0)