@@ -22,21 +22,159 @@ namespace native {
 using Tensor = executorch::aten::Tensor;
 using TensorOptList = executorch::aten::ArrayRef<std::optional<Tensor>>;
 
-Tensor& index_Tensor_out(
+namespace {
+
+bool check_fast_path_conditions(
+    ET_UNUSED const Tensor& in,
+    TensorOptList indices,
+    size_t* dim) {
+  bool found_index = false;
+  for (const auto i : c10::irange(indices.size())) {
+    if (indices[i].has_value()) {
+      *dim = i;
+      // Fast path only supports a single non-null index tensor
+      if (found_index) {
+        return false;
+      }
+      found_index = true;
+      const Tensor& index = indices[i].value();
+      ScalarType ix_type = index.scalar_type();
+      // Fast path only supports Long or Int index tensors
+      if (ix_type != ScalarType::Long && ix_type != ScalarType::Int) {
+        return false;
+      }
+      // Fast path only supports a 1-dimensional index tensor
+      if (index.dim() != 1) {
+        return false;
+      }
+    }
+  }
+
+  // Fast path needs at least one non-null index tensor
+  if (!found_index) {
+    return false;
+  }
+
+  return true;
+}
+
+bool check_fast_path_args(
+    const Tensor& in,
+    TensorOptList indices,
+    size_t dim,
+    Tensor& out) {
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+
+  ET_CHECK_OR_RETURN_FALSE(
+      static_cast<ssize_t>(indices.size()) <= in.dim(),
+      "Indexing too many dimensions");
+
+  const Tensor& index = indices[dim].value();
+
+  bool is_valid_index = true;
+  ET_SWITCH_TWO_TYPES(
+      Long, Int, index.scalar_type(), ctx, "index_put_", CTYPE, [&]() {
+        const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+        for (const auto i : c10::irange(index.numel())) {
+          if (index_arr[i] < 0 ||
+              index_arr[i] >= static_cast<CTYPE>(in.size(dim))) {
+            ET_LOG(
+                Error,
+                "Index %" PRId64
+                " out of range for tensor with size %zd"
+                " at dimension %zu",
+                static_cast<int64_t>(index_arr[i]),
+                in.size(dim),
+                dim);
+            is_valid_index = false;
+            break;
+          }
+        }
+      });
+
+  ET_CHECK_OR_RETURN_FALSE(
+      is_valid_index,
+      "Some index values are not within bounds of input tensor at indexed dim");
+
+  return true;
+}
+
+Tensor& fast_path(
     KernelRuntimeContext& ctx,
     const Tensor& in,
     TensorOptList indices,
+    size_t dim,
     Tensor& out) {
   (void)ctx;
 
   ET_KERNEL_CHECK(
-      ctx, check_index_args(in, indices, out), InvalidArgument, out);
+      ctx, check_fast_path_args(in, indices, dim, out), InvalidArgument, out);
+
+  const Tensor& index = indices[dim].value();
+  ScalarType index_type = index.scalar_type();
+
+  if (out.dim() == 0) {
+    memcpy(out.mutable_data_ptr(), in.const_data_ptr(), out.nbytes());
+    return out;
+  }
+
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_dims = getTrailingDims(in, dim);
+
+  if (leading_dims == 0 || trailing_dims == 0) {
+    return out;
+  }
+
+  size_t in_dim_length = in.size(dim);
+  size_t out_dim_length = out.size(dim);
+
+  size_t length_per_step = trailing_dims * in.element_size();
+
+  const char* in_data = in.const_data_ptr<char>();
+  char* out_data = out.mutable_data_ptr<char>();
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "index.Tensor_out";
+
+  ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
+    const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+    for (const auto i : c10::irange(leading_dims)) {
+      const char* src = in_data + i * in_dim_length * length_per_step;
+      char* dest = out_data + i * out_dim_length * length_per_step;
+      for (const auto j : c10::irange(out_dim_length)) {
+        const char* copy_src = src + index_arr[j] * length_per_step;
+        char* copy_dest = dest + j * length_per_step;
+        memcpy(copy_dest, copy_src, length_per_step);
+      }
+    }
+  });
+
+  return out;
+}
+
+} // namespace
+
+Tensor& index_Tensor_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    TensorOptList indices,
+    Tensor& out) {
+  (void)ctx;
 
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
 
+  size_t dim = 0;
+  bool is_fast_path = check_fast_path_conditions(in, indices, &dim);
+  if (is_fast_path) {
+    return fast_path(ctx, in, indices, dim, out);
+  }
+
+  ET_KERNEL_CHECK(
+      ctx, check_index_args(in, indices, out), InvalidArgument, out);
+
   ScalarType in_type = in.scalar_type();
   size_t block_count = count_index_blocks(indices);
 
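
The fast path added above treats the input as a contiguous [leading_dims, in_dim_length, trailing_dims] buffer and copies one contiguous block of trailing_dims elements per index value. The following is only a rough, standalone illustration of that copy pattern on plain arrays with made-up shapes and values; it does not use the ExecuTorch tensor API.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  // Input "tensor" of shape [2, 3, 2], stored contiguously.
  const size_t leading_dims = 2, in_dim_length = 3, trailing_dims = 2;
  std::vector<float> in = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};

  // A single 1-D index tensor selecting rows {2, 0} along dim 1.
  std::vector<int64_t> index = {2, 0};
  const size_t out_dim_length = index.size();
  std::vector<float> out(leading_dims * out_dim_length * trailing_dims);

  // Each index value selects one contiguous block of trailing_dims elements.
  const size_t length_per_step = trailing_dims * sizeof(float);
  const char* in_data = reinterpret_cast<const char*>(in.data());
  char* out_data = reinterpret_cast<char*>(out.data());

  for (size_t i = 0; i < leading_dims; ++i) {
    const char* src = in_data + i * in_dim_length * length_per_step;
    char* dest = out_data + i * out_dim_length * length_per_step;
    for (size_t j = 0; j < out_dim_length; ++j) {
      std::memcpy(dest + j * length_per_step,
                  src + index[j] * length_per_step,
                  length_per_step);
    }
  }

  // Prints: 4 5 0 1 10 11 6 7 (rows 2 and 0 of each leading slice).
  for (float v : out) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}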