Skip to content

Commit 7152909

Browse files
sxufacebook-github-bot
authored andcommitted
Introduce torch::executor::TensorAccessor
Summary: Replicate the TensorAccessor template from https://github.com/pytorch/pytorch/blob/fc813df1200b530d246eacc710781241c5a9dedf/aten/src/ATen/core/TensorAccessor.h#L73. Differential Revision: D66033489
1 parent e32d1b7 commit 7152909

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

runtime/core/exec_aten/exec_aten.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,63 @@ using torch::executor::compute_numel;
140140

141141
#endif // Use ExecuTorch types
142142

143+
template <typename T, size_t N>
144+
class TensorAccessor {
145+
public:
146+
static TensorAccessor from_tensor(const Tensor& t) {
147+
T* ptr = nullptr;
148+
if constexpr (std::is_const_v<T>) {
149+
ptr = t.const_data_ptr<T>();
150+
} else {
151+
ptr = t.mutable_data_ptr<T>();
152+
}
153+
return TensorAccessor(ptr, t.sizes().data(), t.strides().data());
154+
}
155+
156+
TensorAccessor(
157+
T* data,
158+
const SizesType* sizes,
159+
const StridesType* strides)
160+
: data_(data), sizes_(sizes), strides_(strides) {}
161+
162+
ArrayRef<SizesType> sizes() const {
163+
return ArrayRef<SizesType>(sizes_, N);
164+
}
165+
166+
ArrayRef<StridesType> strides() const {
167+
return ArrayRef<StridesType>(strides_, N);
168+
}
169+
170+
SizesType size(size_t i) const {
171+
return sizes_[i];
172+
}
173+
174+
StridesType stride(size_t i) const {
175+
return strides_[i];
176+
}
177+
178+
T* data() {
179+
return data_;
180+
}
181+
182+
const T* data() const {
183+
return data_;
184+
}
185+
186+
TensorAccessor<T, N - 1> operator[](size_t i) {
187+
return TensorAccessor<T, N - 1>(data_ + strides_[0] * i, sizes_ + 1, strides_ + 1);
188+
}
189+
190+
const TensorAccessor<T, N - 1> operator[](size_t i) const {
191+
return TensorAccessor<T, N - 1>(data_ + strides_[0] * i, sizes_ + 1, strides_ + 1);
192+
}
193+
194+
private:
195+
T* data_;
196+
const SizesType* sizes_;
197+
const StridesType* strides_;
198+
};
199+
143200
} // namespace aten
144201
} // namespace executorch
145202

0 commit comments

Comments
 (0)