Skip to content

Commit 457f436

Browse files
authored
Merge pull request #3559 from eslickj/mpc_time_interp
Add time interpolation to mpc data
2 parents 86f5f05 + 0a25e00 commit 457f436

File tree

3 files changed

+184
-1
lines changed

3 files changed

+184
-1
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# ___________________________________________________________________________
2+
#
3+
# Pyomo: Python Optimization Modeling Objects
4+
# Copyright (c) 2008-2025
5+
# National Technology and Engineering Solutions of Sandia, LLC
6+
# Under the terms of Contract DE-NA0003525 with National Technology and
7+
# Engineering Solutions of Sandia, LLC, the U.S. Government retains certain
8+
# rights in this software.
9+
# This software is distributed under the 3-clause BSD License.
10+
# ___________________________________________________________________________
11+
12+
from bisect import bisect_right
13+
14+
15+
def _get_time_index_vec(time_set, time_data):
16+
"""Get the position index of time_data above and below the times in
17+
time_set. This can be used to find positions of points to interpolate
18+
between.
19+
20+
Parameters
21+
----------
22+
time_set: iterable
23+
Time points to locate
24+
time_data: iterable
25+
Sorted time points to locate time_set in
26+
27+
Returns
28+
-------
29+
numpy.array
30+
Position index of the first time in time_data greater than the
31+
corresponding points time_set. If a time is less than all the times
32+
in time_data return 1. If a time is greater than all times time_data
33+
set return the last index of time_data.
34+
"""
35+
pos = [None] * len(time_set)
36+
for i, t in enumerate(time_set):
37+
pos[i] = bisect_right(time_data, t)
38+
if pos[i] == 0:
39+
pos[i] = 1
40+
elif pos[i] == len(time_data):
41+
pos[i] = len(time_data) - 1
42+
return pos
43+
44+
45+
def _get_interp_expr_vec(time_set, time_data, data, indexes=None):
46+
"""Return an array of floats interpolated at the time points in time_set
47+
from data defined at time_data.
48+
49+
Parameters
50+
----------
51+
time_set: iterable
52+
Time points to locate
53+
time_data: iterable
54+
Sorted time points to locate time_set in
55+
data: iterable
56+
Data corresponding to times in time_data, must have the same
57+
length as time data.
58+
indexes: numpy.array
59+
Numpy array of position indexes of the time points to interpolate in the
60+
time data. The format is the same as returned by ``_get_time_index_vec()``.
61+
If this is None, ``_get_time_index_vec()`` is called. The reason to pass
62+
this is to avoid multiple position searches when interpolating multiple
63+
outputs with the same time points.
64+
65+
Returns
66+
-------
67+
list
68+
If data are Pyomo components, this will return Pyomo expressions.
69+
If data are floats, this will return floats.
70+
"""
71+
if indexes is None:
72+
indexes = _get_time_index_vec(time_set, time_data)
73+
expr = [None] * len(time_set)
74+
for i, (h, t) in enumerate(zip(indexes, time_set)):
75+
l = h - 1
76+
expr[i] = data[l] + (data[h] - data[l]) / (time_data[h] - time_data[l]) * (
77+
t - time_data[l]
78+
)
79+
return expr

pyomo/contrib/mpc/data/series_data.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
from pyomo.contrib.mpc.data.get_cuid import get_indexed_cuid
1616
from pyomo.contrib.mpc.data.dynamic_data_base import _is_iterable, _DynamicDataBase
1717
from pyomo.contrib.mpc.data.scalar_data import ScalarData
18-
18+
from pyomo.contrib.mpc.data.interpolation import (
19+
_get_time_index_vec,
20+
_get_interp_expr_vec,
21+
)
1922

2023
TimeSeriesTuple = namedtuple("TimeSeriesTuple", ["data", "time"])
2124

@@ -152,6 +155,50 @@ def get_data_at_time(self, time=None, tolerance=0.0):
152155
indices = indices[0]
153156
return self.get_data_at_time_indices(indices)
154157

158+
def get_interpolated_data(self, time=None, tolerance=0.0):
159+
"""
160+
Returns the data associated with the provided time point or points by
161+
linear interpolation.
162+
163+
Parameters
164+
----------
165+
time: Float or iterable
166+
The time point or points corresponding to returned data.
167+
tolerance: float
168+
Tolerance used when checking if time points are inside the data
169+
range.
170+
171+
Returns
172+
-------
173+
TimeSeriesData or ~scalar_data.ScalarData
174+
TimeSeriesData containing only the specified time points
175+
or dict mapping CUIDs to values at the specified scalar time
176+
point.
177+
178+
"""
179+
if time is None:
180+
# If time is not specified, assume we want the entire time
181+
# set. Skip all the overhead, don't create a new object, and
182+
# return self.
183+
return self
184+
is_iterable = _is_iterable(time)
185+
if not is_iterable:
186+
time = [time]
187+
for t in time:
188+
if t > self._time[-1] + tolerance or t < self._time[0] - tolerance:
189+
raise RuntimeError("Requesting interpolation outside data range.")
190+
idxs = _get_time_index_vec(time, self._time)
191+
data = {}
192+
for cuid in self._data:
193+
v = _get_interp_expr_vec(time, self._time, self._data[cuid], idxs)
194+
data[cuid] = v
195+
if is_iterable:
196+
return TimeSeriesData(data, list(time))
197+
else:
198+
for cuid in self._data:
199+
data[cuid] = data[cuid][0]
200+
return ScalarData(data)
201+
155202
def to_serializable(self):
156203
"""
157204
Convert to json-serializable object.

pyomo/contrib/mpc/data/tests/test_series_data.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,63 @@ def test_get_data_at_time_with_tolerance(self):
126126
with self.assertRaisesRegex(RuntimeError, msg):
127127
new_data = data.get_data_at_time(-0.01, tolerance=1e-3)
128128

129+
def test_get_data_interpolate(self):
130+
m = self._make_model()
131+
data_dict = {m.var[:, "A"]: [1, 2, 3], m.var[:, "B"]: [2, 4, 6]}
132+
data = TimeSeriesData(data_dict, m.time)
133+
new_data = data.get_interpolated_data(0.05)
134+
self.assertEqual(ScalarData({m.var[:, "A"]: 1.5, m.var[:, "B"]: 3}), new_data)
135+
136+
t1 = 0.05
137+
new_data = data.get_interpolated_data([t1])
138+
self.assertEqual(
139+
TimeSeriesData({m.var[:, "A"]: [1.5], m.var[:, "B"]: [3]}, [t1]), new_data
140+
)
141+
142+
new_t = [0.05, 0.15]
143+
new_data = data.get_interpolated_data(new_t)
144+
self.assertEqual(
145+
TimeSeriesData({m.var[:, "A"]: [1.5, 2.5], m.var[:, "B"]: [3, 5]}, new_t),
146+
new_data,
147+
)
148+
149+
def test_get_data_interpolate_range_check(self):
150+
m = self._make_model()
151+
data_dict = {m.var[:, "A"]: [1, 2, 3], m.var[:, "B"]: [2, 4, 6]}
152+
data = TimeSeriesData(data_dict, m.time)
153+
msg = "Requesting interpolation outside data range."
154+
with self.assertRaisesRegex(RuntimeError, msg):
155+
new_data = data.get_interpolated_data(0.2 + 1e-6)
156+
with self.assertRaisesRegex(RuntimeError, msg):
157+
new_data = data.get_interpolated_data(0.0 - 1e-6)
158+
new_data = data.get_interpolated_data(0.2 + 1e-6, tolerance=1e-5)
159+
self.assertAlmostEqual(new_data.get_data_from_key(m.var[:, "A"]), 3, 4)
160+
self.assertAlmostEqual(new_data.get_data_from_key(m.var[:, "B"]), 6, 4)
161+
162+
t1 = 0.2 + 1e-6
163+
with self.assertRaisesRegex(RuntimeError, msg):
164+
new_data = data.get_interpolated_data([t1])
165+
new_data = data.get_interpolated_data([t1], tolerance=1e-5)
166+
self.assertAlmostEqual(new_data.get_data_from_key(m.var[:, "A"])[0], 3, 4)
167+
self.assertAlmostEqual(new_data.get_data_from_key(m.var[:, "B"])[0], 6, 4)
168+
169+
new_t = [0.0 - 1e-6, 0.2 + 1e-6]
170+
with self.assertRaisesRegex(RuntimeError, msg):
171+
new_data = data.get_interpolated_data(new_t)
172+
new_data = data.get_interpolated_data(new_t, tolerance=1e-5)
173+
self.assertAlmostEqual(new_data.get_data_from_key(m.var[:, "A"])[0], 1, 4)
174+
self.assertAlmostEqual(new_data.get_data_from_key(m.var[:, "B"])[0], 2, 4)
175+
self.assertAlmostEqual(new_data.get_data_from_key(m.var[:, "A"])[1], 3, 4)
176+
self.assertAlmostEqual(new_data.get_data_from_key(m.var[:, "B"])[1], 6, 4)
177+
178+
# check that the exact endpoints don't raise an exception with 0 tol
179+
new_t = [0.0, 0.2]
180+
new_data = data.get_interpolated_data(new_t, tolerance=0.0)
181+
self.assertAlmostEqual(new_data.get_data_from_key(m.var[:, "A"])[0], 1, 4)
182+
self.assertAlmostEqual(new_data.get_data_from_key(m.var[:, "B"])[0], 2, 4)
183+
self.assertAlmostEqual(new_data.get_data_from_key(m.var[:, "A"])[1], 3, 4)
184+
self.assertAlmostEqual(new_data.get_data_from_key(m.var[:, "B"])[1], 6, 4)
185+
129186
def test_to_serializable(self):
130187
m = self._make_model()
131188
data_dict = {m.var[:, "A"]: [1, 2, 3], m.var[:, "B"]: [2, 4, 6]}

0 commit comments

Comments
 (0)