Skip to content

Commit 947d7ac

Browse files
nimrod-gileadicopybara-github
authored andcommitted
Add support for linear and cubic interpolation in TimeSpline class.
PiperOrigin-RevId: 604338935 Change-Id: Ic9c86113b6163c4387c0d9bcc7a8698032a5d2df
1 parent 454c9f8 commit 947d7ac

File tree

3 files changed

+198
-19
lines changed

3 files changed

+198
-19
lines changed

mjpc/spline/spline.cc

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mjpc/spline/spline.h"
1616

1717
#include <algorithm>
18+
#include <array>
1819
#include <cstddef>
1920
#include <utility>
2021
#include <vector>
@@ -25,7 +26,9 @@
2526

2627
namespace mjpc::spline {
2728

28-
TimeSpline::TimeSpline(int dim, int initial_capacity) : dim_(dim) {
29+
TimeSpline::TimeSpline(int dim, SplineInterpolation interpolation,
30+
int initial_capacity)
31+
: interpolation_(interpolation), dim_(dim) {
2932
values_.resize(initial_capacity * dim); // Reserve space for node values
3033
}
3134

@@ -49,6 +52,13 @@ TimeSpline::ConstNode TimeSpline::NodeAt(int index) const {
4952
return ConstNode(times_[index], values_.data() + values_index_, dim_);
5053
}
5154

55+
// Set Interpolation
56+
void TimeSpline::SetInterpolation(SplineInterpolation interpolation) {
57+
interpolation_ = interpolation;
58+
}
59+
60+
SplineInterpolation TimeSpline::Interpolation() const { return interpolation_; }
61+
5262
int TimeSpline::Dim() const { return dim_; }
5363

5464
// Reserves memory for at least num_nodes. If the spline already contains
@@ -97,8 +107,36 @@ void TimeSpline::Sample(double time, absl::Span<double> values) const {
97107
}
98108

99109
auto lower = upper - 1;
100-
ConstNode n = NodeAt(lower - times_.begin());
101-
std::copy(n.values().begin(), n.values().end(), values.begin());
110+
double t = (time - *lower) / (*upper - *lower);
111+
ConstNode lower_node = NodeAt(lower - times_.begin());
112+
ConstNode upper_node = NodeAt(upper - times_.begin());
113+
switch (interpolation_) {
114+
case SplineInterpolation::kZeroSpline:
115+
std::copy(lower_node.values().begin(), lower_node.values().end(),
116+
values.begin());
117+
return;
118+
case SplineInterpolation::kLinearSpline:
119+
for (int i = 0; i < dim_; i++) {
120+
values[i] =
121+
lower_node.values().at(i) * (1 - t) + upper_node.values().at(i) * t;
122+
}
123+
return;
124+
case SplineInterpolation::kCubicSpline: {
125+
std::array<double, 4> coefficients =
126+
CubicCoefficients(time, lower - times_.begin());
127+
for (int i = 0; i < dim_; i++) {
128+
double p0 = lower_node.values().at(i);
129+
double m0 = Slope(lower - times_.begin(), i);
130+
double m1 = Slope(upper - times_.begin(), i);
131+
double p1 = upper_node.values().at(i);
132+
values[i] = coefficients[0] * p0 + coefficients[1] * m0 +
133+
coefficients[2] * p1 + coefficients[3] * m1;
134+
}
135+
return;
136+
}
137+
default:
138+
CHECK(false) << "Unknown interpolation: " << interpolation_;
139+
}
102140
}
103141

104142
std::vector<double> TimeSpline::Sample(double time) const {
@@ -113,8 +151,15 @@ int TimeSpline::DiscardBefore(double time) {
113151
if (last_node == times_.begin()) {
114152
return 0;
115153
}
116-
last_node--;
117154

155+
// If using cubic interpolation, include not just the last node before `time`,
156+
// but the one before that.
157+
int keep_nodes = interpolation_ == SplineInterpolation::kCubicSpline ? 1 : 0;
158+
last_node--;
159+
while (last_node != times_.begin() && keep_nodes) {
160+
last_node--;
161+
keep_nodes--;
162+
}
118163
int nodes_to_remove = last_node - times_.begin();
119164

120165
times_.erase(times_.begin(), last_node);
@@ -175,4 +220,43 @@ TimeSpline::Node TimeSpline::AddNode(double time,
175220
}
176221
return new_node;
177222
}
223+
224+
std::array<double, 4> TimeSpline::CubicCoefficients(
225+
double time, int lower_node_index) const {
226+
std::array<double, 4> coefficients;
227+
int upper_node_index = lower_node_index + 1;
228+
CHECK(upper_node_index != times_.size())
229+
<< "CubicCoefficients shouldn't be called for boundary conditions.";
230+
double lower = times_[lower_node_index];
231+
double upper = times_[upper_node_index];
232+
double t = (time - lower) / (upper - lower);
233+
234+
coefficients[0] = 2.0 * t*t*t - 3.0 * t*t + 1.0;
235+
coefficients[1] =
236+
(t*t*t - 2.0 * t*t + t) * (upper - lower);
237+
coefficients[2] = -2.0 * t*t*t + 3 * t*t;
238+
coefficients[3] = (t*t*t - t*t) * (upper - lower);
239+
240+
return coefficients;
241+
}
242+
243+
double TimeSpline::Slope(int node_index, int value_index) const {
244+
ConstNode node = NodeAt(node_index);
245+
if (node_index == 0) {
246+
ConstNode next = NodeAt(node_index + 1);
247+
// one-sided finite-diff
248+
return (next.values().at(value_index) - node.values().at(value_index)) /
249+
(next.time() - node.time());
250+
}
251+
ConstNode prev = NodeAt(node_index - 1);
252+
if (node_index == times_.size() - 1) {
253+
return (node.values().at(value_index) - prev.values().at(value_index)) /
254+
(node.time() - prev.time());
255+
}
256+
ConstNode next = NodeAt(node_index + 1);
257+
return 0.5 * (next.values().at(value_index) - node.values().at(value_index)) /
258+
(next.time() - node.time()) +
259+
0.5 * (node.values().at(value_index) - prev.values().at(value_index)) /
260+
(node.time() - prev.time());
261+
}
178262
} // namespace mjpc::spline

mjpc/spline/spline.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef MJPC_MJPC_SPLINE_SPLINE_H_
1616
#define MJPC_MJPC_SPLINE_SPLINE_H_
1717

18+
#include <array>
1819
#include <cstddef>
1920
#include <deque>
2021
#include <vector>
@@ -23,13 +24,22 @@
2324

2425
namespace mjpc::spline {
2526

27+
enum SplineInterpolation : int {
28+
kZeroSpline,
29+
kLinearSpline,
30+
kCubicSpline,
31+
};
32+
33+
2634
// Represents a spline where values are interpolated based on time.
2735
// Allows updating the spline by adding new future points, or removing old
2836
// nodes.
2937
// This class is not thread safe and requires locking to use.
3038
class TimeSpline {
3139
public:
32-
explicit TimeSpline(int dim = 0, int initial_capacity = 1);
40+
explicit TimeSpline(int dim = 0,
41+
SplineInterpolation interpolation = kZeroSpline,
42+
int initial_capacity = 1);
3343

3444
// Copyable, Movable.
3545
TimeSpline(const TimeSpline& other) = default;
@@ -72,11 +82,15 @@ class TimeSpline {
7282
// Returns the number of nodes in the spline.
7383
std::size_t Size() const;
7484

85+
7586
// Returns the node at the given index, sorted by time. Any calls that mutate
7687
// the spline will invalidate the Node object.
7788
Node NodeAt(int index);
7889
ConstNode NodeAt(int index) const;
7990

91+
void SetInterpolation(SplineInterpolation interpolation);
92+
SplineInterpolation Interpolation() const;
93+
8094
// Returns the dimensionality of interpolation values.
8195
int Dim() const;
8296

@@ -85,8 +99,6 @@ class TimeSpline {
8599
void Reserve(int num_nodes);
86100

87101
// Interpolates values based on time, writes results to `values`.
88-
// NOTE: The current implementation does a "zero-order interpolation". Linear
89-
// and cubic interpolations will be added in a follow-up.
90102
void Sample(double time, absl::Span<double> values) const;
91103
// Interpolates values based on time, returns a vector of length Dim.
92104
std::vector<double> Sample(double time) const;
@@ -105,6 +117,11 @@ class TimeSpline {
105117
Node AddNode(double time, absl::Span<const double> values);
106118

107119
private:
120+
std::array<double, 4> CubicCoefficients(double time,
121+
int lower_node_index) const;
122+
double Slope(int node_index, int value_index) const;
123+
SplineInterpolation interpolation_;
124+
108125
int dim_;
109126

110127
// The time values for each node. This is kept sorted.

0 commit comments

Comments
 (0)