Skip to content

Commit 55df62e

Browse files
nimrod-gileadicopybara-github
authored andcommitted
Add an iterator interface for TimeSpline class.
PiperOrigin-RevId: 604373518 Change-Id: I4a1fc781f6db12c73a2b3036b43039a8e311335f
1 parent 2cf2877 commit 55df62e

File tree

3 files changed

+246
-1
lines changed

3 files changed

+246
-1
lines changed

mjpc/spline/spline.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,22 @@ TimeSpline::ConstNode TimeSpline::NodeAt(int index) const {
5252
return ConstNode(times_[index], values_.data() + values_index_, dim_);
5353
}
5454

55+
TimeSpline::iterator TimeSpline::begin() {
56+
return TimeSpline::iterator(this, 0);
57+
}
58+
59+
TimeSpline::iterator TimeSpline::end() {
60+
return TimeSpline::iterator(this, times_.size());
61+
}
62+
63+
TimeSpline::const_iterator TimeSpline::cbegin() const {
64+
return TimeSpline::const_iterator(this, 0);
65+
}
66+
67+
TimeSpline::const_iterator TimeSpline::cend() const {
68+
return TimeSpline::const_iterator(this, times_.size());
69+
}
70+
5571
// Set Interpolation
5672
void TimeSpline::SetInterpolation(SplineInterpolation interpolation) {
5773
interpolation_ = interpolation;

mjpc/spline/spline.h

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818
#include <array>
1919
#include <cstddef>
2020
#include <deque>
21+
#include <iterator>
22+
#include <type_traits>
2123
#include <vector>
2224

25+
#include <absl/log/check.h>
2326
#include <absl/types/span.h>
2427

2528
namespace mjpc::spline {
@@ -79,6 +82,145 @@ class TimeSpline {
7982
using Node = NodeT<double>;
8083
using ConstNode = NodeT<const double>;
8184

85+
// Iterator type for TimeSpline.
86+
// SplineType is TimeSpline or const TimeSpline.
87+
// NodeType is Node or ConstNode.
88+
template <typename SplineType, typename NodeType>
89+
class IteratorT {
90+
public:
91+
using iterator_category = std::random_access_iterator_tag;
92+
using value_type = typename std::remove_cv_t<NodeType>;
93+
using difference_type = int;
94+
using pointer = NodeType*;
95+
using reference = NodeType&;
96+
97+
IteratorT(SplineType* spline = nullptr, int index = 0)
98+
: spline_(spline), index_(index) {
99+
if (spline_ != nullptr && index_ != spline->Size()) {
100+
node_ = spline->NodeAt(index_);
101+
}
102+
}
103+
104+
// Copyable, Movable.
105+
IteratorT<SplineType, NodeType>(
106+
const IteratorT<SplineType, NodeType>& other) = default;
107+
IteratorT<SplineType, NodeType>& operator=(
108+
const IteratorT<SplineType, NodeType>& other) = default;
109+
IteratorT<SplineType, NodeType>(IteratorT<SplineType, NodeType>&& other) =
110+
default;
111+
IteratorT<SplineType, NodeType>& operator=(
112+
IteratorT<SplineType, NodeType>&& other) = default;
113+
114+
reference operator*() { return node_; }
115+
116+
pointer operator->() { return &node_; }
117+
pointer operator->() const { return &node_; }
118+
119+
IteratorT<SplineType, NodeType>& operator++() {
120+
++index_;
121+
node_ = index_ == spline_->Size() ? NodeType() : spline_->NodeAt(index_);
122+
return *this;
123+
}
124+
125+
IteratorT<SplineType, NodeType> operator++(int) {
126+
IteratorT<SplineType, NodeType> tmp = *this;
127+
++(*this);
128+
return tmp;
129+
}
130+
131+
IteratorT<SplineType, NodeType>& operator--() {
132+
--index_;
133+
node_ = spline_->NodeAt(index_);
134+
return *this;
135+
}
136+
137+
IteratorT<SplineType, NodeType> operator--(int) {
138+
IteratorT<SplineType, NodeType> tmp = *this;
139+
--(*this);
140+
return tmp;
141+
}
142+
143+
IteratorT<SplineType, NodeType>& operator+=(difference_type n) {
144+
if (n != 0) {
145+
index_ += n;
146+
node_ =
147+
index_ == spline_->Size() ? NodeType() : spline_->NodeAt(index_);
148+
}
149+
return *this;
150+
}
151+
152+
IteratorT<SplineType, NodeType>& operator-=(difference_type n) {
153+
return *this += -n;
154+
}
155+
156+
IteratorT<SplineType, NodeType> operator+(difference_type n) const {
157+
IteratorT<SplineType, NodeType> tmp(*this);
158+
tmp += n;
159+
return tmp;
160+
}
161+
162+
IteratorT<SplineType, NodeType> operator-(difference_type n) const {
163+
IteratorT<SplineType, NodeType> tmp(*this);
164+
tmp -= n;
165+
return tmp;
166+
}
167+
168+
friend IteratorT<SplineType, NodeType> operator+(
169+
difference_type n, const IteratorT<SplineType, NodeType>& it) {
170+
return it + n;
171+
}
172+
173+
friend difference_type operator-(const IteratorT<SplineType, NodeType>& x,
174+
const IteratorT<SplineType, NodeType>& y) {
175+
CHECK_EQ(x.spline_, y.spline_)
176+
<< "Comparing iterators from different splines";
177+
if (x != y) return (x.index_ - y.index_);
178+
return 0;
179+
}
180+
181+
NodeType operator[](difference_type n) const { return *(*this + n); }
182+
183+
friend bool operator==(const IteratorT<SplineType, NodeType>& x,
184+
const IteratorT<SplineType, NodeType>& y) {
185+
return x.spline_ == y.spline_ && x.index_ == y.index_;
186+
}
187+
188+
friend bool operator!=(const IteratorT<SplineType, NodeType>& x,
189+
const IteratorT<SplineType, NodeType>& y) {
190+
return !(x == y);
191+
}
192+
193+
friend bool operator<(const IteratorT<SplineType, NodeType>& x,
194+
const IteratorT<SplineType, NodeType>& y) {
195+
CHECK_EQ(x.spline_, y.spline_)
196+
<< "Comparing iterators from different splines";
197+
return x.index_ < y.index_;
198+
}
199+
200+
friend bool operator>(const IteratorT<SplineType, NodeType>& x,
201+
const IteratorT<SplineType, NodeType>& y) {
202+
return y < x;
203+
}
204+
205+
friend bool operator<=(const IteratorT<SplineType, NodeType>& x,
206+
const IteratorT<SplineType, NodeType>& y) {
207+
return !(y < x);
208+
}
209+
210+
friend bool operator>=(const IteratorT<SplineType, NodeType>& x,
211+
const IteratorT<SplineType, NodeType>& y) {
212+
return !(x < y);
213+
}
214+
215+
private:
216+
SplineType* spline_ = nullptr;
217+
int index_ = 0;
218+
NodeType node_;
219+
};
220+
221+
using iterator = IteratorT<TimeSpline, Node>;
222+
using const_iterator = IteratorT<const TimeSpline, ConstNode>;
223+
82224
// Returns the number of nodes in the spline.
83225
std::size_t Size() const;
84226

@@ -88,6 +230,13 @@ class TimeSpline {
88230
Node NodeAt(int index);
89231
ConstNode NodeAt(int index) const;
90232

233+
// Returns an iterator that iterates over spline nodes in time order.
234+
// Callers must not mutate `time`, but can modify values in `values`.
235+
iterator begin();
236+
iterator end();
237+
const_iterator cbegin() const;
238+
const_iterator cend() const;
239+
91240
void SetInterpolation(SplineInterpolation interpolation);
92241
SplineInterpolation Interpolation() const;
93242

mjpc/test/spline/spline_test.cc

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,10 @@ TEST_P(TimeSplineReserveTest, CopyConstructor) {
299299
EXPECT_EQ(spline2.Size(), 4);
300300

301301
// overwrite values in original spline
302-
spline.Clear();
302+
for (TimeSpline::Node& n : spline) {
303+
n.values()[0] = 3.0;
304+
n.values()[1] = 4.0;
305+
}
303306

304307
// spline2 should be unaffected
305308
EXPECT_THAT(spline2.Sample(1.5), ElementsAre(2.0, 3.0));
@@ -366,6 +369,83 @@ TEST(TimeSplineTest, Dim0) {
366369
spline.Sample(1, absl::MakeSpan(values.data(), values.size()));
367370
}
368371

372+
template <typename T>
373+
T BeginIterator(TimeSpline& spline) {
374+
if constexpr (std::is_same_v<T, TimeSpline::const_iterator>) {
375+
return spline.cbegin();
376+
} else {
377+
return spline.begin();
378+
}
379+
}
380+
381+
template <typename T>
382+
T EndIterator(TimeSpline& spline) {
383+
if constexpr (std::is_same_v<T, TimeSpline::const_iterator>) {
384+
return spline.cend();
385+
} else {
386+
return spline.end();
387+
}
388+
}
389+
390+
template <typename T>
391+
void TestIterator() {
392+
// Tests that the iterator type T complies with random_access_iterator_tag.
393+
TimeSpline spline(/*dim=*/2);
394+
spline.Reserve(10);
395+
spline.AddNode(1.0, {1.0, 2.0});
396+
spline.AddNode(2.0, {3.0, 4.0});
397+
spline.AddNode(3.0, {5.0, 6.0});
398+
399+
T it = BeginIterator<T>(spline);
400+
EXPECT_EQ(it->values()[0], 1.0);
401+
EXPECT_EQ((it + 2)->values()[0], 5.0);
402+
EXPECT_EQ((2 + it)->values()[0], 5.0);
403+
EXPECT_EQ(EndIterator<T>(spline) - it, 3);
404+
EXPECT_EQ((EndIterator<T>(spline) - 1) - it, 2);
405+
406+
it++;
407+
EXPECT_EQ(it->values()[0], 3.0);
408+
++it;
409+
EXPECT_EQ(it->values()[0], 5.0);
410+
it--;
411+
EXPECT_EQ(it->values()[0], 3.0);
412+
--it;
413+
EXPECT_EQ(it->values()[0], 1.0);
414+
415+
EXPECT_LT(it, it + 2);
416+
EXPECT_LE(it, it + 2);
417+
EXPECT_LE(it, it + 0);
418+
EXPECT_GT(it + 2, it);
419+
EXPECT_GE(it + 2, it);
420+
EXPECT_GE(it + 0, it);
421+
EXPECT_EQ((it + 2) - 2, it);
422+
423+
TimeSpline spline2 = spline;
424+
EXPECT_NE(BeginIterator<T>(spline), BeginIterator<T>(spline2))
425+
<< "Iterators from different splines should not be equal";
426+
427+
// Copy constructor
428+
T it_copy = it;
429+
EXPECT_EQ(it, it_copy);
430+
431+
auto node2 = it[2];
432+
EXPECT_EQ(node2.values()[0], 5.0);
433+
434+
// Iterators are swappable
435+
it_copy += spline.Size();
436+
std::swap(it_copy, it);
437+
EXPECT_EQ(it_copy, BeginIterator<T>(spline));
438+
EXPECT_EQ(it, EndIterator<T>(spline));
439+
}
440+
441+
TEST(TimeSplineTest, Iterator) {
442+
TestIterator<TimeSpline::iterator>();
443+
}
444+
445+
TEST(TimeSplineTest, ConstIterator) {
446+
TestIterator<TimeSpline::const_iterator>();
447+
}
448+
369449
INSTANTIATE_TEST_SUITE_P(
370450
TimeSplineAllInterpolations, TimeSplineAllInterpolationsTest,
371451
testing::ValuesIn<TimeSplineTestCase>({

0 commit comments

Comments
 (0)