Skip to content

Commit 454c9f8

Browse files
nimrod-gileadicopybara-github
authored andcommitted
Create TimeSpline class.
It allows one to keep a spline keyed by time, with the ability to drop old nodes in the spline. PiperOrigin-RevId: 604297920 Change-Id: Iaf3fb97716e9e9f644a7c09afe5a11bda90acf1c
1 parent 7017ba0 commit 454c9f8

File tree

7 files changed

+637
-1
lines changed

7 files changed

+637
-1
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ jobs:
6868
$cmake_extra_args
6969
- name: Build MuJoCo MPC
7070
working-directory: build
71-
run: cmake --build . --config=Release ${{ matrix.cmake_build_args }} --target mjpc agent_test agent_utilities_test cost_derivatives_test norm_test rollout_test threadpool_test trajectory_test direct_force_test direct_optimize_test direct_parameter_test direct_sensor_test direct_trajectory_test direct_utilities_test batch_filter_test batch_prior_test kalman_test unscented_test cubic_test gradient_planner_test gradient_test linear_test zero_test backward_pass_test ilqg_test robust_planner_test sampling_planner_test state_test task_test utilities_test ${{ matrix.additional_targets }}
71+
run: cmake --build . --config=Release ${{ matrix.cmake_build_args }} --target mjpc agent_test agent_utilities_test cost_derivatives_test norm_test rollout_test threadpool_test trajectory_test direct_force_test direct_optimize_test direct_parameter_test direct_sensor_test direct_trajectory_test direct_utilities_test batch_filter_test batch_prior_test kalman_test spline_test unscented_test cubic_test gradient_planner_test gradient_test linear_test zero_test backward_pass_test ilqg_test robust_planner_test sampling_planner_test state_test task_test utilities_test ${{ matrix.additional_targets }}
7272
- name: Test MuJoCo MPC
7373
working-directory: build
7474
run: ctest -C Release --output-on-failure .

mjpc/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ add_library(
120120
direct/trajectory.h
121121
direct/model_parameters.cc
122122
direct/model_parameters.h
123+
spline/spline.cc
124+
spline/spline.h
123125
app.cc
124126
app.h
125127
norm.cc

mjpc/spline/spline.cc

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
// Copyright 2024 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "mjpc/spline/spline.h"
16+
17+
#include <algorithm>
18+
#include <cstddef>
19+
#include <utility>
20+
#include <vector>
21+
22+
#include <absl/log/check.h>
23+
#include <absl/log/log.h>
24+
#include <absl/types/span.h>
25+
26+
namespace mjpc::spline {
27+
28+
TimeSpline::TimeSpline(int dim, int initial_capacity) : dim_(dim) {
29+
values_.resize(initial_capacity * dim); // Reserve space for node values
30+
}
31+
32+
std::size_t TimeSpline::Size() const { return times_.size(); }
33+
34+
TimeSpline::Node TimeSpline::NodeAt(int index) {
35+
int values_index_ = values_begin_ + index * dim_;
36+
if (values_index_ >= values_.size()) {
37+
values_index_ -= values_.size();
38+
CHECK_LE(values_index_, values_.size());
39+
}
40+
return Node(times_[index], values_.data() + values_index_, dim_);
41+
}
42+
43+
TimeSpline::ConstNode TimeSpline::NodeAt(int index) const {
44+
int values_index_ = values_begin_ + index * dim_;
45+
if (values_index_ >= values_.size()) {
46+
values_index_ -= values_.size();
47+
CHECK_LE(values_index_, values_.size());
48+
}
49+
return ConstNode(times_[index], values_.data() + values_index_, dim_);
50+
}
51+
52+
int TimeSpline::Dim() const { return dim_; }
53+
54+
// Reserves memory for at least num_nodes. If the spline already contains
55+
// more nodes, does nothing.
56+
void TimeSpline::Reserve(int num_nodes) {
57+
if (num_nodes * dim_ <= values_.size()) {
58+
return;
59+
}
60+
if (values_begin_ < values_end_ || times_.empty()) {
61+
// Easy case: just resize the values_ vector and remap the spans if needed,
62+
// without any further data copies.
63+
values_.resize(num_nodes * dim_);
64+
} else {
65+
std::vector<double> new_values(num_nodes * dim_);
66+
// Copy all existing values to the start of the new vector
67+
std::copy(values_.begin() + values_begin_, values_.end(),
68+
new_values.begin());
69+
std::copy(values_.begin(), values_.begin() + values_end_,
70+
new_values.begin() + values_.size() - values_begin_);
71+
values_ = std::move(new_values);
72+
values_begin_ = 0;
73+
values_end_ = times_.size() * dim_;
74+
}
75+
}
76+
77+
void TimeSpline::Sample(double time, absl::Span<double> values) const {
78+
CHECK_EQ(values.size(), dim_)
79+
<< "Tried to sample " << values.size()
80+
<< " values, but the dimensionality of the spline is " << dim_;
81+
82+
if (times_.empty()) {
83+
std::fill(values.begin(), values.end(), 0.0);
84+
return;
85+
}
86+
87+
auto upper = std::upper_bound(times_.begin(), times_.end(), time);
88+
if (upper == times_.end()) {
89+
ConstNode n = NodeAt(upper - times_.begin() - 1);
90+
std::copy(n.values().begin(), n.values().end(), values.begin());
91+
return;
92+
}
93+
if (upper == times_.begin()) {
94+
ConstNode n = NodeAt(upper - times_.begin());
95+
std::copy(n.values().begin(), n.values().end(), values.begin());
96+
return;
97+
}
98+
99+
auto lower = upper - 1;
100+
ConstNode n = NodeAt(lower - times_.begin());
101+
std::copy(n.values().begin(), n.values().end(), values.begin());
102+
}
103+
104+
std::vector<double> TimeSpline::Sample(double time) const {
105+
std::vector<double> values(dim_);
106+
Sample(time, absl::MakeSpan(values));
107+
return values;
108+
}
109+
110+
int TimeSpline::DiscardBefore(double time) {
111+
// Find the first node that has n.time > time.
112+
auto last_node = std::upper_bound(times_.begin(), times_.end(), time);
113+
if (last_node == times_.begin()) {
114+
return 0;
115+
}
116+
last_node--;
117+
118+
int nodes_to_remove = last_node - times_.begin();
119+
120+
times_.erase(times_.begin(), last_node);
121+
values_begin_ += dim_ * nodes_to_remove;
122+
if (values_begin_ >= values_.size()) {
123+
values_begin_ -= values_.size();
124+
CHECK_LE(values_begin_, values_.size());
125+
}
126+
return nodes_to_remove;
127+
}
128+
129+
void TimeSpline::Clear() {
130+
times_.clear();
131+
values_begin_ = 0;
132+
values_end_ = 0;
133+
// Don't change capacity_ or reset values_.
134+
}
135+
136+
// Adds a new set of values at the given time. Implementation is only
137+
// efficient if time is later than any previously added nodes.
138+
TimeSpline::Node TimeSpline::AddNode(double time) {
139+
return AddNode(time, absl::Span<const double>()); // Default empty values
140+
}
141+
142+
TimeSpline::Node TimeSpline::AddNode(double time,
143+
absl::Span<const double> new_values) {
144+
CHECK(new_values.size() == dim_ || new_values.empty());
145+
// TODO(nimrod): Implement node insertion in the middle of the spline
146+
CHECK(times_.empty() || time > times_.back() || time < times_.front())
147+
<< "Adding nodes to the middle of the spline isn't supported.";
148+
if (times_.size() * dim_ >= values_.size()) {
149+
Reserve(times_.size() * 2);
150+
}
151+
Node new_node;
152+
if (times_.empty() || time > times_.back()) {
153+
CHECK_LE(values_end_ + dim_, values_.size());
154+
times_.push_back(time);
155+
values_end_ += dim_;
156+
if (values_end_ >= values_.size()) {
157+
CHECK_EQ(values_end_, values_.size());
158+
values_end_ -= values_.size();
159+
}
160+
new_node = NodeAt(times_.size() - 1);
161+
} else {
162+
CHECK_LT(time, times_.front());
163+
values_begin_ -= dim_;
164+
if (values_begin_ < 0) {
165+
values_begin_ += values_.size();
166+
}
167+
CHECK_LE(values_begin_ + dim_, values_.size());
168+
times_.push_front(time);
169+
new_node = NodeAt(0);
170+
}
171+
if (!new_values.empty()) {
172+
std::copy(new_values.begin(), new_values.end(), new_node.values().begin());
173+
} else {
174+
std::fill(new_node.values().begin(), new_node.values().end(), 0.0);
175+
}
176+
return new_node;
177+
}
178+
} // namespace mjpc::spline

mjpc/spline/spline.h

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Copyright 2024 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef MJPC_MJPC_SPLINE_SPLINE_H_
16+
#define MJPC_MJPC_SPLINE_SPLINE_H_
17+
18+
#include <cstddef>
19+
#include <deque>
20+
#include <vector>
21+
22+
#include <absl/types/span.h>
23+
24+
namespace mjpc::spline {
25+
26+
// Represents a spline where values are interpolated based on time.
27+
// Allows updating the spline by adding new future points, or removing old
28+
// nodes.
29+
// This class is not thread safe and requires locking to use.
30+
class TimeSpline {
31+
public:
32+
explicit TimeSpline(int dim = 0, int initial_capacity = 1);
33+
34+
// Copyable, Movable.
35+
TimeSpline(const TimeSpline& other) = default;
36+
TimeSpline& operator=(const TimeSpline& other) = default;
37+
TimeSpline(TimeSpline&& other) = default;
38+
TimeSpline& operator=(TimeSpline&& other) = default;
39+
40+
// A view into one spline node in the spline.
41+
// Template parameter is needed to support both `double` and `const double`
42+
// views of the data.
43+
template <typename T>
44+
class NodeT {
45+
public:
46+
NodeT() : time_(0) {};
47+
NodeT(double time, T* values, int dim)
48+
: time_(time), values_(values, dim) {}
49+
50+
// Copyable, Movable.
51+
NodeT(const NodeT& other) = default;
52+
NodeT& operator=(const NodeT& other) = default;
53+
NodeT(NodeT&& other) = default;
54+
NodeT& operator=(NodeT&& other) = default;
55+
56+
double time() const { return time_; }
57+
58+
// Returns a span pointing to the spline values of the node.
59+
// This function returns a non-const span, to allow spline values to be
60+
// modified, while the time member and underlying values pointer remain
61+
// constant.
62+
absl::Span<T> values() const { return values_; }
63+
64+
private:
65+
double time_;
66+
absl::Span<T> values_;
67+
};
68+
69+
using Node = NodeT<double>;
70+
using ConstNode = NodeT<const double>;
71+
72+
// Returns the number of nodes in the spline.
73+
std::size_t Size() const;
74+
75+
// Returns the node at the given index, sorted by time. Any calls that mutate
76+
// the spline will invalidate the Node object.
77+
Node NodeAt(int index);
78+
ConstNode NodeAt(int index) const;
79+
80+
// Returns the dimensionality of interpolation values.
81+
int Dim() const;
82+
83+
// Reserves memory for at least num_nodes. If the spline already contains
84+
// more nodes, does nothing.
85+
void Reserve(int num_nodes);
86+
87+
// Interpolates values based on time, writes results to `values`.
88+
// NOTE: The current implementation does a "zero-order interpolation". Linear
89+
// and cubic interpolations will be added in a follow-up.
90+
void Sample(double time, absl::Span<double> values) const;
91+
// Interpolates values based on time, returns a vector of length Dim.
92+
std::vector<double> Sample(double time) const;
93+
94+
// Removes any old nodes that have no effect on the values at time `time`.
95+
// Returns the number of nodes removed.
96+
int DiscardBefore(double time);
97+
98+
// Removes all existing nodes.
99+
void Clear();
100+
101+
// Adds a new set of values at the given time.
102+
// This class only supports adding nodes with a time later or earlier than
103+
// all other nodes.
104+
Node AddNode(double time);
105+
Node AddNode(double time, absl::Span<const double> values);
106+
107+
private:
108+
int dim_;
109+
110+
// The time values for each node. This is kept sorted.
111+
std::deque<double> times_;
112+
113+
// The raw node values. Stored in a ring buffer, which is resized whenever
114+
// too many nodes are added.
115+
std::vector<double> values_;
116+
117+
// The index in values_ for the data of the earliest node.
118+
int values_begin_ = 0;
119+
120+
// One past the index in values_ for the end of the data of the last node.
121+
// If values_end_ == values_begin_, either there's no data (nodes_ is empty),
122+
// or the values_ buffer is full.
123+
int values_end_ = 0;
124+
};
125+
126+
} // namespace mjpc::spline
127+
128+
#endif // MJPC_MJPC_SPLINE_SPLINE_H_

mjpc/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ add_subdirectory(gradient_planner)
7171
add_subdirectory(ilqg_planner)
7272
add_subdirectory(planners/robust)
7373
add_subdirectory(sampling_planner)
74+
add_subdirectory(spline)
7475
add_subdirectory(state)
7576
add_subdirectory(tasks)
7677
add_subdirectory(utilities)

mjpc/test/spline/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
test(spline_test)
16+
target_link_libraries(spline_test gmock)

0 commit comments

Comments
 (0)