@@ -17,6 +17,36 @@ namespace et
1717{
1818
1919struct Tensor ;
20+
21+ template <typename T>
22+ struct ETALER_EXPORT TensorIterator
23+ {
24+ // Iterator properties
25+ using iterator_category = std::bidirectional_iterator_tag;
26+ using value_type = T;
27+ using raw_value_type = std::remove_const_t <value_type>; // extra
28+ using difference_type = intmax_t ;
29+ using pointer = std::unique_ptr<raw_value_type>;
30+ using reference = T&;
31+
32+ using ThisIterator = TensorIterator<T>;
33+ TensorIterator () = default ;
34+ TensorIterator (reference t, intmax_t offset = 0 ) : t_(&t), offset_(offset)
35+ {static_assert (std::is_same_v<raw_value_type, Tensor>); }
36+ value_type operator *() { return t_->view ({offset_}); }
37+ // Unfortunatelly returning a pointer is not doable
38+ pointer operator ->() { return std::make_unique<raw_value_type>(*(*this )); }
39+ bool operator ==(ThisIterator rhs) const { return offset_ == rhs.offset_ && t_ == rhs.t_ ; }
40+ bool operator !=(ThisIterator rhs) const { return !(*this == rhs); }
41+ ThisIterator& operator ++() {offset_ += 1 ; return *this ;}
42+ ThisIterator operator ++(int ) {ThisIterator retval = *this ; ++(*this ); return retval;}
43+ ThisIterator& operator --() {offset_ -= 1 ; return *this ;}
44+ ThisIterator operator --(int ) {ThisIterator retval = *this ; --(*this ); return retval;}
45+ value_type* t_ = nullptr ; // Using a pointer because Tensor is a incomplete type here
46+ intmax_t offset_ = 0 ;
47+ };
48+
49+
2050Tensor ETALER_EXPORT brodcast_to (const Tensor& t, Shape s);
2151
2252ETALER_EXPORT std::ostream& operator << (std::ostream& os, const Tensor& t);
@@ -204,6 +234,17 @@ struct ETALER_EXPORT Tensor
204234 TensorImpl* operator () () {return pimpl ();}
205235 const TensorImpl* operator () () const {return pimpl ();}
206236
237+ using iterator = TensorIterator<Tensor>;
238+ using const_iterator = TensorIterator<const Tensor>;
239+
240+ iterator begin () { return iterator (*this , 0 ); }
241+ iterator back () { return iterator (*this , shape ()[0 ]-1 ); }
242+ iterator end () { return iterator (*this , shape ()[0 ]); }
243+
244+ const_iterator begin () const { return const_iterator (*this , 0 ); }
245+ const_iterator back () const { return const_iterator (*this , shape ()[0 ]-1 ); }
246+ const_iterator end () const { return const_iterator (*this , shape ()[0 ]); }
247+
207248 bool has_value () const {return (bool )pimpl_ && size () > 0 ;}
208249
209250 std::pair<Tensor, Tensor> brodcast (const Tensor& other) const ;
@@ -251,7 +292,7 @@ inline Tensor realize(const Tensor& t)
251292
252293inline Tensor ravel (const Tensor& t)
253294{
254- if (t.iscontiguous () == false )
295+ if (t.iscontiguous () == true )
255296 return t;
256297 return t.realize ();
257298}
@@ -313,6 +354,13 @@ inline void assign(Tensor& x, const Tensor& y)
313354 x.assign (y);
314355}
315356
357+ inline void swap (Tensor x, Tensor y)
358+ {
359+ Tensor tmp = ravel (x).copy ();
360+ x.assign (y);
361+ y.assign (tmp);
362+ }
363+
316364Tensor ETALER_EXPORT sum (const Tensor& x, std::optional<intmax_t > dim=std::nullopt , DType dtype=DType::Unknown);
317365Tensor ETALER_EXPORT cat (const svector<Tensor>& tensors, intmax_t dim=0 );
318366inline Tensor concat (const svector<Tensor>& tensors, intmax_t dim=0 ) { return cat (tensors, dim); }
0 commit comments