@@ -22,21 +22,189 @@ namespace native {
 using Tensor = executorch::aten::Tensor;
 using TensorOptList = executorch::aten::ArrayRef<std::optional<Tensor>>;
 
-Tensor& index_Tensor_out(
+namespace {
+
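+// Returns true when the fast path applies: exactly one non-null index
+// tensor, of dtype Long or Int, with a single dimension. On success,
+// *dim is set to the position of that index tensor within `indices`.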
+bool check_fast_path_conditions(
+    ET_UNUSED const Tensor& in,
+    TensorOptList indices,
+    size_t* dim) {
+  bool found_index = false;
+  for (const auto i : c10::irange(indices.size())) {
+    if (indices[i].has_value()) {
+      *dim = i;
+      // Fast path only supports a single non-null index tensor
+      if (found_index) {
+        return false;
+      }
+      found_index = true;
+      const Tensor& index = indices[i].value();
+      ScalarType ix_type = index.scalar_type();
+      // Fast path only supports Long or Int index tensors
+      if (ix_type != ScalarType::Long && ix_type != ScalarType::Int) {
+        return false;
+      }
+      // Fast path only supports a 1-dimensional index tensor
+      if (index.dim() != 1) {
+        return false;
+      }
+    }
+  }
+
+  // Fast path needs at least one non-null index tensor
+  if (!found_index) {
+    return false;
+  }
+
+  return true;
+}
+
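+// Validates the fast-path arguments: in and out must share a dtype, the
+// number of index tensors must not exceed in.dim(), and every index value
+// must lie within [0, in.size(dim)).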
+bool check_fast_path_args(
     KernelRuntimeContext& ctx,
     const Tensor& in,
     TensorOptList indices,
+    size_t dim,
     Tensor& out) {
-  (void)ctx;
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+
+  ET_CHECK_OR_RETURN_FALSE(
+      static_cast<ssize_t>(indices.size()) <= in.dim(),
+      "Indexing too many dimensions");
+
+  const Tensor& index = indices[dim].value();
+
+  bool is_valid_index = true;
+  ET_SWITCH_TWO_TYPES(
+      Long, Int, index.scalar_type(), ctx, "index.Tensor", CTYPE, [&]() {
+        const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+        for (const auto i : c10::irange(index.numel())) {
+          if (index_arr[i] < 0 ||
+              index_arr[i] >= static_cast<CTYPE>(in.size(dim))) {
+            ET_LOG(
+                Error,
+                "Index %" PRId64
+                " out of range for tensor with size %zd"
+                " at dimension %zu",
+                static_cast<int64_t>(index_arr[i]),
+                in.size(dim),
+                dim);
+            is_valid_index = false;
+            break;
+          }
+        }
+      });
+
+  ET_CHECK_OR_RETURN_FALSE(
+      is_valid_index,
+      "Some index values are not within bounds of input tensor at indexed dim");
 
+  return true;
+}
+
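+// Computes the expected output shape: identical to the input shape except
+// that the size at the indexed dimension becomes the number of index
+// entries.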
+void get_fast_path_index_out_target_size(
+    const Tensor& in,
+    TensorOptList indices,
+    size_t dim,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim) {
+  *out_ndim = in.dim();
+
+  for (const auto d : c10::irange(static_cast<size_t>(in.dim()))) {
+    if (d != dim) {
+      out_sizes[d] = static_cast<Tensor::SizesType>(in.size(d));
+    } else {
+      out_sizes[d] =
+          static_cast<Tensor::SizesType>(indices[dim].value().numel());
+    }
+  }
+}
+
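+// Fast path for a single 1-D integer index tensor: the input is viewed as
+// (leading_dims, in.size(dim), trailing_dims), and for each selected index
+// one contiguous block of trailing_dims elements is copied with memcpy.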
+Tensor& fast_path(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    TensorOptList indices,
+    size_t dim,
+    Tensor& out) {
   ET_KERNEL_CHECK(
-      ctx, check_index_args(in, indices, out), InvalidArgument, out);
+      ctx,
+      check_fast_path_args(ctx, in, indices, dim, out),
+      InvalidArgument,
+      out);
+
+  const Tensor& index = indices[dim].value();
+  ScalarType index_type = index.scalar_type();
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  Tensor::SizesType expected_size[kTensorDimensionLimit];
+  size_t expected_ndim = 0;
+  get_fast_path_index_out_target_size(
+      in, indices, dim, expected_size, &expected_ndim);
 
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, {expected_size, expected_ndim}) == Error::Ok,
+      InvalidArgument,
+      out);
+
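+  // A 0-D output holds a single element, so copy it directly.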
+  if (out.dim() == 0) {
+    memcpy(out.mutable_data_ptr(), in.const_data_ptr(), out.nbytes());
+    return out;
+  }
+
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_dims = getTrailingDims(in, dim);
+
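+  // If any non-indexed dimension is empty, there is nothing to copy.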
+  if (leading_dims == 0 || trailing_dims == 0) {
+    return out;
+  }
+
+  size_t in_dim_length = in.size(dim);
+  size_t out_dim_length = out.size(dim);
+
+  size_t length_per_step = trailing_dims * in.element_size();
+
+  const char* in_data = in.const_data_ptr<char>();
+  char* out_data = out.mutable_data_ptr<char>();
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "index.Tensor_out";
+
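+  // For each leading slice, copy the trailing block selected by index_arr[j]
+  // out of the input row into output row j.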
+  ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
+    const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+    for (const auto i : c10::irange(leading_dims)) {
+      const char* src = in_data + i * in_dim_length * length_per_step;
+      char* dest = out_data + i * out_dim_length * length_per_step;
+      for (const auto j : c10::irange(out_dim_length)) {
+        const char* copy_src = src + index_arr[j] * length_per_step;
+        char* copy_dest = dest + j * length_per_step;
+        memcpy(copy_dest, copy_src, length_per_step);
+      }
+    }
+  });
+
+  return out;
+}
+
+} // namespace
+
+Tensor& index_Tensor_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    TensorOptList indices,
+    Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
 
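+  // Take the fast path when a single 1-D Long/Int index tensor is present;
+  // otherwise fall back to the generic implementation below.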
+  size_t dim = 0;
+  bool is_fast_path = check_fast_path_conditions(in, indices, &dim);
+  if (is_fast_path) {
+    return fast_path(ctx, in, indices, dim, out);
+  }
+
+  ET_KERNEL_CHECK(
+      ctx, check_index_args(in, indices, out), InvalidArgument, out);
+
   ScalarType in_type = in.scalar_type();
   size_t block_count = count_index_blocks(indices);
 