@@ -22,21 +22,189 @@ namespace native {
using Tensor = executorch::aten::Tensor;
using TensorOptList = executorch::aten::ArrayRef<std::optional<Tensor>>;

- Tensor& index_Tensor_out(
+ namespace {
+
+ bool check_fast_path_conditions(
+     ET_UNUSED const Tensor& in,
+     TensorOptList indices,
+     size_t* dim) {
+   bool found_index = false;
+   for (const auto i : c10::irange(indices.size())) {
+     if (indices[i].has_value()) {
+       *dim = i;
+       // Fast path only supports a single non-null index tensor
+       if (found_index) {
+         return false;
+       }
+       found_index = true;
+       const Tensor& index = indices[i].value();
+       ScalarType ix_type = index.scalar_type();
+       // Fast path only supports Long or Int index tensors
+       if (ix_type != ScalarType::Long && ix_type != ScalarType::Int) {
+         return false;
+       }
+       // Fast path only supports a 1-dimensional index tensor
+       if (index.dim() != 1) {
+         return false;
+       }
+     }
+   }
+
+   // Fast path needs at least one non-null index tensor
+   if (!found_index) {
+     return false;
+   }
+
+   return true;
+ }
+
+ bool check_fast_path_args(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    TensorOptList indices,
+     size_t dim,
    Tensor& out) {
-   (void)ctx;
+   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+
+   ET_CHECK_OR_RETURN_FALSE(
+       static_cast<ssize_t>(indices.size()) <= in.dim(),
+       "Indexing too many dimensions");
+
+   const Tensor& index = indices[dim].value();
+
+   bool is_valid_index = true;
+   ET_SWITCH_TWO_TYPES(
+       Long, Int, index.scalar_type(), ctx, "index.Tensor", CTYPE, [&]() {
+         const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+         for (const auto i : c10::irange(index.numel())) {
+           if (index_arr[i] < 0 ||
+               index_arr[i] >= static_cast<CTYPE>(in.size(dim))) {
+             ET_LOG(
+                 Error,
+                 "Index %" PRId64
+                 " out of range for tensor with size %zd"
+                 " at dimension %zu",
+                 static_cast<int64_t>(index_arr[i]),
+                 in.size(dim),
+                 dim);
+             is_valid_index = false;
+             break;
+           }
+         }
+       });
+
+   ET_CHECK_OR_RETURN_FALSE(
+       is_valid_index,
+       "Some index values are not within bounds of input tensor at indexed dim");

+   return true;
+ }
+
+ void get_fast_path_index_out_target_size(
+     const Tensor& in,
+     TensorOptList indices,
+     size_t dim,
+     Tensor::SizesType* out_sizes,
+     size_t* out_ndim) {
+   *out_ndim = in.dim();
+
+   for (const auto d : c10::irange(static_cast<size_t>(in.dim()))) {
+     if (d != dim) {
+       out_sizes[d] = static_cast<Tensor::SizesType>(in.size(d));
+     } else {
+       out_sizes[d] =
+           static_cast<Tensor::SizesType>(indices[dim].value().numel());
+     }
+   }
+ }
+
+ Tensor& fast_path(
+     KernelRuntimeContext& ctx,
+     const Tensor& in,
+     TensorOptList indices,
+     size_t dim,
+     Tensor& out) {
  ET_KERNEL_CHECK(
-       ctx, check_index_args(in, indices, out), InvalidArgument, out);
+       ctx,
+       check_fast_path_args(ctx, in, indices, dim, out),
+       InvalidArgument,
+       out);
+
+   const Tensor& index = indices[dim].value();
+   ScalarType index_type = index.scalar_type();
+
+   // @lint-ignore CLANGTIDY facebook-hte-CArray
+   Tensor::SizesType expected_size[kTensorDimensionLimit];
+   size_t expected_ndim = 0;
+   get_fast_path_index_out_target_size(
+       in, indices, dim, expected_size, &expected_ndim);

+   ET_KERNEL_CHECK(
+       ctx,
+       resize_tensor(out, {expected_size, expected_ndim}) == Error::Ok,
+       InvalidArgument,
+       out);
+
+   if (out.dim() == 0) {
+     memcpy(out.mutable_data_ptr(), in.const_data_ptr(), out.nbytes());
+     return out;
+   }
+
+   size_t leading_dims = getLeadingDims(in, dim);
+   size_t trailing_dims = getTrailingDims(in, dim);
+
+   if (leading_dims == 0 || trailing_dims == 0) {
+     return out;
+   }
+
+   size_t in_dim_length = in.size(dim);
+   size_t out_dim_length = out.size(dim);
+
+   size_t length_per_step = trailing_dims * in.element_size();
+
+   const char* in_data = in.const_data_ptr<char>();
+   char* out_data = out.mutable_data_ptr<char>();
+
+   // @lint-ignore CLANGTIDY facebook-hte-CArray
+   static constexpr const char op_name[] = "index.Tensor_out";
+
+   ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
+     const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+     for (const auto i : c10::irange(leading_dims)) {
+       const char* src = in_data + i * in_dim_length * length_per_step;
+       char* dest = out_data + i * out_dim_length * length_per_step;
+       for (const auto j : c10::irange(out_dim_length)) {
+         const char* copy_src = src + index_arr[j] * length_per_step;
+         char* copy_dest = dest + j * length_per_step;
+         memcpy(copy_dest, copy_src, length_per_step);
+       }
+     }
+   });
+
+   return out;
+ }
+
+ } // namespace
+
+ Tensor& index_Tensor_out(
+     KernelRuntimeContext& ctx,
+     const Tensor& in,
+     TensorOptList indices,
+     Tensor& out) {
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

+   size_t dim = 0;
+   bool is_fast_path = check_fast_path_conditions(in, indices, &dim);
+   if (is_fast_path) {
+     return fast_path(ctx, in, indices, dim, out);
+   }
+
+   ET_KERNEL_CHECK(
+       ctx, check_index_args(in, indices, out), InvalidArgument, out);
+
  ScalarType in_type = in.scalar_type();
  size_t block_count = count_index_blocks(indices);
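For reference, the copy performed by fast_path is a plain gather along dim: for each of the leading_dims outer slices, it copies trailing_dims-sized rows selected by the 1-D index tensor. The standalone sketch below reproduces that loop on a raw buffer; the shape [2, 4, 3], dim == 1, and index values {3, 0} are illustrative assumptions, not values taken from the patch.

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  // Illustrative input: contiguous float "tensor" of shape [2, 4, 3], dim == 1.
  const size_t leading_dims = 2;   // product of sizes before dim
  const size_t in_dim_length = 4;  // in.size(dim)
  const size_t trailing_dims = 3;  // product of sizes after dim
  std::vector<float> in(leading_dims * in_dim_length * trailing_dims);
  for (size_t i = 0; i < in.size(); ++i) {
    in[i] = static_cast<float>(i);
  }

  // 1-D index selecting rows 3 and 0 along dim 1.
  const std::vector<long> index = {3, 0};
  const size_t out_dim_length = index.size();
  std::vector<float> out(leading_dims * out_dim_length * trailing_dims);

  // Same loop structure as fast_path: one memcpy per selected row.
  const size_t length_per_step = trailing_dims * sizeof(float);
  const char* in_data = reinterpret_cast<const char*>(in.data());
  char* out_data = reinterpret_cast<char*>(out.data());
  for (size_t i = 0; i < leading_dims; ++i) {
    const char* src = in_data + i * in_dim_length * length_per_step;
    char* dest = out_data + i * out_dim_length * length_per_step;
    for (size_t j = 0; j < out_dim_length; ++j) {
      std::memcpy(
          dest + j * length_per_step,
          src + index[j] * length_per_step,
          length_per_step);
    }
  }

  // out has shape [2, 2, 3]; out[0][0] is in[0][3], i.e. {9, 10, 11}.
  std::printf("%g %g %g\n", out[0], out[1], out[2]);
  return 0;
}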