@@ -140,6 +140,63 @@ using torch::executor::compute_numel;
140140
141141#endif // Use ExecuTorch types
142142
143+ template <typename T, size_t N>
144+ class TensorAccessor {
145+ public:
146+ static TensorAccessor from_tensor (const Tensor& t) {
147+ T* ptr = nullptr ;
148+ if constexpr (std::is_const_v<T>) {
149+ ptr = t.const_data_ptr <T>();
150+ } else {
151+ ptr = t.mutable_data_ptr <T>();
152+ }
153+ return TensorAccessor (ptr, t.sizes ().data (), t.strides ().data ());
154+ }
155+
156+ TensorAccessor (
157+ T* data,
158+ const SizesType* sizes,
159+ const StridesType* strides)
160+ : data_(data), sizes_(sizes), strides_(strides) {}
161+
162+ ArrayRef<SizesType> sizes () const {
163+ return ArrayRef<SizesType>(sizes_, N);
164+ }
165+
166+ ArrayRef<StridesType> strides () const {
167+ return ArrayRef<StridesType>(strides_, N);
168+ }
169+
170+ SizesType size (size_t i) const {
171+ return sizes_[i];
172+ }
173+
174+ StridesType stride (size_t i) const {
175+ return strides_[i];
176+ }
177+
178+ T* data () {
179+ return data_;
180+ }
181+
182+ const T* data () const {
183+ return data_;
184+ }
185+
186+ TensorAccessor<T, N - 1 > operator [](size_t i) {
187+ return TensorAccessor<T, N - 1 >(data_ + strides_[0 ] * i, sizes_ + 1 , strides_ + 1 );
188+ }
189+
190+ const TensorAccessor<T, N - 1 > operator [](size_t i) const {
191+ return TensorAccessor<T, N - 1 >(data_ + strides_[0 ] * i, sizes_ + 1 , strides_ + 1 );
192+ }
193+
194+ private:
195+ T* data_;
196+ const SizesType* sizes_;
197+ const StridesType* strides_;
198+ };
199+
143200} // namespace aten
144201} // namespace executorch
145202
0 commit comments