@@ -22,21 +22,159 @@ namespace native {
2222using Tensor = executorch::aten::Tensor;
2323using TensorOptList = executorch::aten::ArrayRef<std::optional<Tensor>>;
2424
25- Tensor& index_Tensor_out (
25+ namespace {
26+
27+ bool check_fast_path_conditions (
28+ ET_UNUSED const Tensor& in,
29+ TensorOptList indices,
30+ size_t * dim) {
31+ bool found_index = false ;
32+ for (const auto i : c10::irange (indices.size ())) {
33+ if (indices[i].has_value ()) {
34+ *dim = i;
35+ // Fast path only supports a single non-null index tensor
36+ if (found_index) {
37+ return false ;
38+ }
39+ found_index = true ;
40+ const Tensor& index = indices[i].value ();
41+ ScalarType ix_type = index.scalar_type ();
42+ // Fast path only supports only supports Long or Int index tensors
43+ if (ix_type != ScalarType::Long && ix_type != ScalarType::Int) {
44+ return false ;
45+ }
46+ // Fast path only supports a 1-dimensional index tensor
47+ if (index.dim () != 1 ) {
48+ return false ;
49+ }
50+ }
51+ }
52+
53+ // Fast path only supports needs at least one non-null index tensor
54+ if (!found_index) {
55+ return false ;
56+ }
57+
58+ return true ;
59+ }
60+
61+ bool check_fast_path_args (
62+ const Tensor& in,
63+ TensorOptList indices,
64+ size_t dim,
65+ Tensor& out) {
66+ ET_LOG_AND_RETURN_IF_FALSE (tensors_have_same_dtype (in, out));
67+
68+ ET_CHECK_OR_RETURN_FALSE (
69+ static_cast <ssize_t >(indices.size ()) <= in.dim (),
70+ " Indexing too many dimensions" );
71+
72+ const Tensor& index = indices[dim].value ();
73+
74+ bool is_valid_index = true ;
75+ ET_SWITCH_TWO_TYPES (
76+ Long, Int, index.scalar_type (), ctx, " index_put_" , CTYPE, [&]() {
77+ const CTYPE* const index_arr = index.const_data_ptr <CTYPE>();
78+ for (const auto i : c10::irange (index.numel ())) {
79+ if (index_arr[i] < 0 ||
80+ index_arr[i] >= static_cast <CTYPE>(in.size (dim))) {
81+ ET_LOG (
82+ Error,
83+ " Index %" PRId64
84+ " out of range for tensor with size %zd"
85+ " at dimension %zu" ,
86+ static_cast <int64_t >(index_arr[i]),
87+ in.size (dim),
88+ dim);
89+ is_valid_index = false ;
90+ break ;
91+ }
92+ }
93+ });
94+
95+ ET_CHECK_OR_RETURN_FALSE (
96+ is_valid_index,
97+ " Some index values are not within bounds of input tensor at indexed dim" );
98+
99+ return true ;
100+ }
101+
102+ Tensor& fast_path (
26103 KernelRuntimeContext& ctx,
27104 const Tensor& in,
28105 TensorOptList indices,
106+ size_t dim,
29107 Tensor& out) {
30108 (void )ctx;
31109
32110 ET_KERNEL_CHECK (
33- ctx, check_index_args (in, indices, out), InvalidArgument, out);
111+ ctx, check_fast_path_args (in, indices, dim, out), InvalidArgument, out);
112+
113+ const Tensor& index = indices[dim].value ();
114+ ScalarType index_type = index.scalar_type ();
115+
116+ if (out.dim () == 0 ) {
117+ memcpy (out.mutable_data_ptr (), in.const_data_ptr (), out.nbytes ());
118+ return out;
119+ }
120+
121+ size_t leading_dims = getLeadingDims (in, dim);
122+ size_t trailing_dims = getTrailingDims (in, dim);
123+
124+ if (leading_dims == 0 || trailing_dims == 0 ) {
125+ return out;
126+ }
127+
128+ size_t in_dim_length = in.size (dim);
129+ size_t out_dim_length = out.size (dim);
130+
131+ size_t length_per_step = trailing_dims * in.element_size ();
132+
133+ const char * in_data = in.const_data_ptr <char >();
134+ char * out_data = out.mutable_data_ptr <char >();
135+
136+ // @lint-ignore CLANGTIDY facebook-hte-CArray
137+ static constexpr const char op_name[] = " index.Tensor_out" ;
138+
139+ ET_SWITCH_TWO_TYPES (Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
140+ const CTYPE* const index_arr = index.const_data_ptr <CTYPE>();
141+ for (const auto i : c10::irange (leading_dims)) {
142+ const char * src = in_data + i * in_dim_length * length_per_step;
143+ char * dest = out_data + i * out_dim_length * length_per_step;
144+ for (const auto j : c10::irange (out_dim_length)) {
145+ const char * copy_src = src + index_arr[j] * length_per_step;
146+ char * copy_dest = dest + j * length_per_step;
147+ memcpy (copy_dest, copy_src, length_per_step);
148+ }
149+ }
150+ });
151+
152+ return out;
153+ }
154+
155+ } // namespace
156+
157+ Tensor& index_Tensor_out (
158+ KernelRuntimeContext& ctx,
159+ const Tensor& in,
160+ TensorOptList indices,
161+ Tensor& out) {
162+ (void )ctx;
34163
35164 ET_KERNEL_CHECK (
36165 ctx, tensors_have_same_dim_order (in, out), InvalidArgument, out);
37166
38167 ET_KERNEL_CHECK (ctx, tensor_is_default_dim_order (in), InvalidArgument, out);
39168
169+ size_t dim = 0 ;
170+ bool is_fast_path = check_fast_path_conditions (in, indices, &dim);
171+ if (is_fast_path) {
172+ return fast_path (ctx, in, indices, dim, out);
173+ }
174+
175+ ET_KERNEL_CHECK (
176+ ctx, check_index_args (in, indices, out), InvalidArgument, out);
177+
40178 ScalarType in_type = in.scalar_type ();
41179 size_t block_count = count_index_blocks (indices);
42180
0 commit comments