Skip to content

Commit 9be40a5

Browse files

Merge pull request #121 from marty1885/master
Make tensor support iteration

2 parents 9f18299 + 9281a71 · commit 9be40a5

File tree

5 files changed

+115
-8
lines changed

5 files changed

+115
-8
lines changed

Etaler/Backends/OpenCLBackend.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,6 @@ std::shared_ptr<TensorImpl> OpenCLBackend::reverseBurst(const TensorImpl* x)
519519
std::vector<uint32_t> seed1(global_size);
520520
std::vector<uint32_t> seed2(global_size);
521521

522-
for(auto& v : seed1) v = rng();
523522
for(auto& v : seed1) v = rng();
524523

525524
auto s1 = createTensor({global_size}, DType::Int32, seed1.data());
@@ -670,7 +669,7 @@ int location_func$ID(int location)
670669

671670
replaceAll(func, "$STRIDE", to_string(x->stride()));
672671
replaceAll(func, "$BIAS", std::to_string(x->offset()));
673-
return func;
672+
return func;
674673
}
675674

676675
static std::vector<std::string> jitCopyFromView(const TensorImpl* x)
@@ -730,7 +729,7 @@ kernel void copy(global Type* restrict x, global Type* restrict y)
730729
std::shared_ptr<TensorImpl> OpenCLBackend::realize(const TensorImpl* x)
731730
{
732731
requireProperties(x, this);
733-
if(x->iscontiguous() == true)
732+
if(x->isplain() == true)
734733
return copy(x);
735734

736735
std::vector<std::string> conversion = jitCopyFromView(x);

Etaler/Core/Tensor.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ Tensor Tensor::view(svector<Range> ranges) const
191191
Shape result_shape;
192192
svector<intmax_t> offset;
193193
Shape viewed_strides = pimpl_->stride();
194+
Shape result_stride;
194195
offset.reserve(dimentions());
195196

196197
assert(viewed_strides.size() == dimentions());
@@ -220,16 +221,21 @@ Tensor Tensor::view(svector<Range> ranges) const
220221
viewed_strides[i] *= step;
221222

222223
offset.push_back(real_start);
223-
if(size != 1 || result_shape.empty() == false) //Ignore heading 1 dimentions
224+
if(size != 1 || result_shape.empty() == false) { //Ignore heading 1 dimentions
224225
result_shape.push_back(size);
226+
result_stride.push_back(viewed_strides[i]);
227+
}
225228
}
226229

227230
//If all dims are 1, thus no shape. Give it a shape
228-
if(result_shape.empty() == true)
231+
if(result_shape.empty() == true) {
232+
et_assert(result_stride.size() == result_shape.size());
229233
result_shape.push_back(1);
234+
result_stride.push_back(1);
235+
}
230236

231237
size_t initial_offset = unfold(offset, pimpl_->stride())+pimpl_->offset();
232-
return std::make_shared<TensorImpl>(pimpl_->buffer(), result_shape, viewed_strides, initial_offset);
238+
return std::make_shared<TensorImpl>(pimpl_->buffer(), result_shape, result_stride, initial_offset);
233239
}
234240

235241
Tensor et::zeros(const Shape& shape, DType dtype, Backend* backend)
@@ -346,7 +352,9 @@ Tensor et::cat(const svector<Tensor>& tensors, intmax_t dim)
346352

347353
Tensor Tensor::copy() const
348354
{
349-
return backend()->copy(pimpl());
355+
if(iscontiguous() == true)
356+
return backend()->copy(pimpl());
357+
return realize().copy();
350358
}
351359

352360
inline bool brodcastable(Shape a, Shape b)

Etaler/Core/Tensor.hpp

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,36 @@ namespace et
1717
{
1818

1919
struct Tensor;
20+
21+
template <typename T>
22+
struct ETALER_EXPORT TensorIterator
23+
{
24+
// Iterator properties
25+
using iterator_category = std::bidirectional_iterator_tag;
26+
using value_type = T;
27+
using raw_value_type = std::remove_const_t<value_type>; // extra
28+
using difference_type = intmax_t;
29+
using pointer = std::unique_ptr<raw_value_type>;
30+
using reference = T&;
31+
32+
using ThisIterator = TensorIterator<T>;
33+
TensorIterator() = default;
34+
TensorIterator(reference t, intmax_t offset = 0) : t_(&t), offset_(offset)
35+
{static_assert(std::is_same_v<raw_value_type, Tensor>); }
36+
value_type operator*() { return t_->view({offset_}); }
37+
// Unfortunatelly returning a pointer is not doable
38+
pointer operator->() { return std::make_unique<raw_value_type>(*(*this)); }
39+
bool operator==(ThisIterator rhs) const { return offset_ == rhs.offset_ && t_ == rhs.t_; }
40+
bool operator!=(ThisIterator rhs) const { return !(*this == rhs); }
41+
ThisIterator& operator++() {offset_ += 1; return *this;}
42+
ThisIterator operator++(int) {ThisIterator retval = *this; ++(*this); return retval;}
43+
ThisIterator& operator--() {offset_ -= 1; return *this;}
44+
ThisIterator operator--(int) {ThisIterator retval = *this; --(*this); return retval;}
45+
value_type* t_ = nullptr; // Using a pointer because Tensor is a incomplete type here
46+
intmax_t offset_ = 0;
47+
};
48+
49+
2050
Tensor ETALER_EXPORT brodcast_to(const Tensor& t, Shape s);
2151

2252
ETALER_EXPORT std::ostream& operator<< (std::ostream& os, const Tensor& t);
@@ -204,6 +234,17 @@ struct ETALER_EXPORT Tensor
204234
TensorImpl* operator () () {return pimpl();}
205235
const TensorImpl* operator () () const {return pimpl();}
206236

237+
using iterator = TensorIterator<Tensor>;
238+
using const_iterator = TensorIterator<const Tensor>;
239+
240+
iterator begin() { return iterator(*this, 0); }
241+
iterator back() { return iterator(*this, shape()[0]-1); }
242+
iterator end() { return iterator(*this, shape()[0]); }
243+
244+
const_iterator begin() const { return const_iterator(*this, 0); }
245+
const_iterator back() const { return const_iterator(*this, shape()[0]-1); }
246+
const_iterator end() const { return const_iterator(*this, shape()[0]); }
247+
207248
bool has_value() const {return (bool)pimpl_ && size() > 0;}
208249

209250
std::pair<Tensor, Tensor> brodcast(const Tensor& other) const;
@@ -251,7 +292,7 @@ inline Tensor realize(const Tensor& t)
251292

252293
inline Tensor ravel(const Tensor& t)
253294
{
254-
if(t.iscontiguous() == false)
295+
if(t.iscontiguous() == true)
255296
return t;
256297
return t.realize();
257298
}
@@ -313,6 +354,13 @@ inline void assign(Tensor& x, const Tensor& y)
313354
x.assign(y);
314355
}
315356

357+
inline void swap(Tensor x, Tensor y)
358+
{
359+
Tensor tmp = ravel(x).copy();
360+
x.assign(y);
361+
y.assign(tmp);
362+
}
363+
316364
Tensor ETALER_EXPORT sum(const Tensor& x, std::optional<intmax_t> dim=std::nullopt, DType dtype=DType::Unknown);
317365
Tensor ETALER_EXPORT cat(const svector<Tensor>& tensors, intmax_t dim=0);
318366
inline Tensor concat(const svector<Tensor>& tensors, intmax_t dim=0) { return cat(tensors, dim); }

Etaler/Core/TensorImpl.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ struct ETALER_EXPORT TensorImpl : public std::enable_shared_from_this<TensorImpl
5656
};
5757

5858
struct IsContingous {};
59+
struct IsPlain {};
5960

6061
template <typename Storage>
6162
struct IsDType
@@ -78,6 +79,8 @@ bool checkProperty(const TensorImpl* x, const T& value)
7879
return x->dtype() == value;
7980
else if constexpr(std::is_same_v<T, IsContingous>)
8081
return x->iscontiguous();
82+
else if constexpr(std::is_same_v<T, IsPlain>)
83+
return x->iscontiguous();
8184
else if constexpr(is_specialization<std::remove_pointer_t<std::decay_t<T>>, IsDType>::value)
8285
return (std::find(value.types.begin(), value.types.end(), x->dtype()) != value.types.end());
8386
else
@@ -99,6 +102,8 @@ void requireProperty(const TensorImpl* x, const T value, const std::string& line
99102
throw EtError(msg + ".dtype() == " + to_ctype_string(value));
100103
else if constexpr(std::is_same_v<T, IsContingous>)
101104
throw EtError(msg + ".iscontiguous() == true");
105+
else if constexpr(std::is_same_v<T, IsPlain>)
106+
throw EtError(msg + ".isplain() == true");
102107
else if constexpr(is_specialization<std::remove_pointer_t<std::decay_t<T>>, IsDType>::value) {
103108
throw EtError(msg + ".dtype() is in {" + std::accumulate(value.types.begin(), value.types.end(), std::string()
104109
, [](auto v, auto a){return v + to_ctype_string(a) + ", ";}));

tests/common_tests.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ TEST_CASE("Testing Tensor", "[Tensor]")
162162

163163
CHECK_NOTHROW(requireProperties(ones(Shape{1}, DType::Int32).pimpl(), IsContingous()));
164164
CHECK_THROWS(requireProperties(ones(Shape{4,4}, DType::Int32).view({range(2), range(2)}).pimpl(), IsContingous()));
165+
166+
CHECK_NOTHROW(requireProperties(ones(Shape{1}, DType::Int32).pimpl(), IsPlain()));
167+
CHECK_NOTHROW(requireProperties(ones(Shape{1}, DType::Int32).view({0}).pimpl(), IsPlain()));
168+
CHECK_THROWS(requireProperties(ones(Shape{4,4}, DType::Int32).view({range(2), range(2)}).pimpl(), IsPlain()));
165169
}
166170

167171
SECTION("Views") {
@@ -208,6 +212,13 @@ TEST_CASE("Testing Tensor", "[Tensor]")
208212
CHECK(realize(r).isSame(pred));
209213
}
210214

215+
SECTION("View of views") {
216+
Tensor t = ones({4, 4});
217+
Tensor v1 = t[{3}];
218+
Tensor v2 = v1[{all()}];
219+
CHECK(v2.size() == 4);
220+
}
221+
211222
SECTION("View write back") {
212223
Tensor q = t.view({range(2),range(2)});
213224
CHECK_THROWS(q.assign(ones({5,5})));
@@ -286,6 +297,42 @@ TEST_CASE("Testing Tensor", "[Tensor]")
286297
// item() should fail because q is not a scalar
287298
CHECK_THROWS(q.item<int>());
288299
}
300+
301+
SECTION("iterator") {
302+
Tensor t = ones({3, 4});
303+
Tensor q = zeros({3, 4});
304+
STATIC_REQUIRE(std::is_same_v<Tensor::iterator::value_type, Tensor>);
305+
306+
// Tensor::iterator should be bideractional
307+
// Reference: http://www.cplusplus.com/reference/iterator/BidirectionalIterator/
308+
STATIC_REQUIRE(std::is_default_constructible_v<Tensor::iterator>);
309+
STATIC_REQUIRE(std::is_copy_constructible_v<Tensor::iterator>);
310+
STATIC_REQUIRE(std::is_copy_assignable_v<Tensor::iterator>);
311+
STATIC_REQUIRE(std::is_destructible_v<Tensor::iterator>);
312+
CHECK(t.begin() != t.end());
313+
CHECK(t.begin() == t.begin());
314+
CHECK((*t.begin()).shape() == Shape{4});
315+
CHECK(t.begin()->shape() == Shape{4});
316+
auto it1 = t.begin(), it2 = t.begin();
317+
it1++;
318+
++it2;
319+
CHECK(it1 == it2);
320+
--it1;
321+
it2--;
322+
CHECK(it1 == it2);
323+
324+
swap(*t.begin(), *q.begin());
325+
CHECK(t[{0}].isSame(zeros({4})));
326+
327+
int num_iteration = 0;
328+
for(auto s : t) {
329+
CHECK(s.shape() == Shape({4}));
330+
s.assign(constant({4}, 42));
331+
num_iteration += 1;
332+
}
333+
CHECK(num_iteration == t.shape()[0]);
334+
CHECK(t.sum().item<int>() == 42*t.size());
335+
}
289336
}
290337

291338
TEST_CASE("Testing Encoders", "[Encoder]")

0 commit comments

Comments (0)