1515#include " mjpc/spline/spline.h"
1616
1717#include < algorithm>
18+ #include < array>
1819#include < cstddef>
1920#include < utility>
2021#include < vector>
2526
2627namespace mjpc ::spline {
2728
28- TimeSpline::TimeSpline (int dim, int initial_capacity) : dim_(dim) {
29+ TimeSpline::TimeSpline (int dim, SplineInterpolation interpolation,
30+ int initial_capacity)
31+ : interpolation_(interpolation), dim_(dim) {
2932 values_.resize (initial_capacity * dim); // Reserve space for node values
3033}
3134
@@ -49,6 +52,13 @@ TimeSpline::ConstNode TimeSpline::NodeAt(int index) const {
4952 return ConstNode (times_[index], values_.data () + values_index_, dim_);
5053}
5154
55+ // Set Interpolation
56+ void TimeSpline::SetInterpolation (SplineInterpolation interpolation) {
57+ interpolation_ = interpolation;
58+ }
59+
60+ SplineInterpolation TimeSpline::Interpolation () const { return interpolation_; }
61+
5262int TimeSpline::Dim () const { return dim_; }
5363
5464// Reserves memory for at least num_nodes. If the spline already contains
@@ -97,8 +107,36 @@ void TimeSpline::Sample(double time, absl::Span<double> values) const {
97107 }
98108
99109 auto lower = upper - 1 ;
100- ConstNode n = NodeAt (lower - times_.begin ());
101- std::copy (n.values ().begin (), n.values ().end (), values.begin ());
110+ double t = (time - *lower) / (*upper - *lower);
111+ ConstNode lower_node = NodeAt (lower - times_.begin ());
112+ ConstNode upper_node = NodeAt (upper - times_.begin ());
113+ switch (interpolation_) {
114+ case SplineInterpolation::kZeroSpline :
115+ std::copy (lower_node.values ().begin (), lower_node.values ().end (),
116+ values.begin ());
117+ return ;
118+ case SplineInterpolation::kLinearSpline :
119+ for (int i = 0 ; i < dim_; i++) {
120+ values[i] =
121+ lower_node.values ().at (i) * (1 - t) + upper_node.values ().at (i) * t;
122+ }
123+ return ;
124+ case SplineInterpolation::kCubicSpline : {
125+ std::array<double , 4 > coefficients =
126+ CubicCoefficients (time, lower - times_.begin ());
127+ for (int i = 0 ; i < dim_; i++) {
128+ double p0 = lower_node.values ().at (i);
129+ double m0 = Slope (lower - times_.begin (), i);
130+ double m1 = Slope (upper - times_.begin (), i);
131+ double p1 = upper_node.values ().at (i);
132+ values[i] = coefficients[0 ] * p0 + coefficients[1 ] * m0 +
133+ coefficients[2 ] * p1 + coefficients[3 ] * m1;
134+ }
135+ return ;
136+ }
137+ default :
138+ CHECK (false ) << " Unknown interpolation: " << interpolation_;
139+ }
102140}
103141
104142std::vector<double > TimeSpline::Sample (double time) const {
@@ -113,8 +151,15 @@ int TimeSpline::DiscardBefore(double time) {
113151 if (last_node == times_.begin ()) {
114152 return 0 ;
115153 }
116- last_node--;
117154
155+ // If using cubic interpolation, include not just the last node before `time`,
156+ // but the one before that.
157+ int keep_nodes = interpolation_ == SplineInterpolation::kCubicSpline ? 1 : 0 ;
158+ last_node--;
159+ while (last_node != times_.begin () && keep_nodes) {
160+ last_node--;
161+ keep_nodes--;
162+ }
118163 int nodes_to_remove = last_node - times_.begin ();
119164
120165 times_.erase (times_.begin (), last_node);
@@ -175,4 +220,43 @@ TimeSpline::Node TimeSpline::AddNode(double time,
175220 }
176221 return new_node;
177222}
223+
224+ std::array<double , 4 > TimeSpline::CubicCoefficients (
225+ double time, int lower_node_index) const {
226+ std::array<double , 4 > coefficients;
227+ int upper_node_index = lower_node_index + 1 ;
228+ CHECK (upper_node_index != times_.size ())
229+ << " CubicCoefficients shouldn't be called for boundary conditions." ;
230+ double lower = times_[lower_node_index];
231+ double upper = times_[upper_node_index];
232+ double t = (time - lower) / (upper - lower);
233+
234+ coefficients[0 ] = 2.0 * t*t*t - 3.0 * t*t + 1.0 ;
235+ coefficients[1 ] =
236+ (t*t*t - 2.0 * t*t + t) * (upper - lower);
237+ coefficients[2 ] = -2.0 * t*t*t + 3 * t*t;
238+ coefficients[3 ] = (t*t*t - t*t) * (upper - lower);
239+
240+ return coefficients;
241+ }
242+
243+ double TimeSpline::Slope (int node_index, int value_index) const {
244+ ConstNode node = NodeAt (node_index);
245+ if (node_index == 0 ) {
246+ ConstNode next = NodeAt (node_index + 1 );
247+ // one-sided finite-diff
248+ return (next.values ().at (value_index) - node.values ().at (value_index)) /
249+ (next.time () - node.time ());
250+ }
251+ ConstNode prev = NodeAt (node_index - 1 );
252+ if (node_index == times_.size () - 1 ) {
253+ return (node.values ().at (value_index) - prev.values ().at (value_index)) /
254+ (node.time () - prev.time ());
255+ }
256+ ConstNode next = NodeAt (node_index + 1 );
257+ return 0.5 * (next.values ().at (value_index) - node.values ().at (value_index)) /
258+ (next.time () - node.time ()) +
259+ 0.5 * (node.values ().at (value_index) - prev.values ().at (value_index)) /
260+ (node.time () - prev.time ());
261+ }
178262} // namespace mjpc::spline
0 commit comments