@@ -22,159 +22,21 @@ namespace native {
 using Tensor = executorch::aten::Tensor;
 using TensorOptList = executorch::aten::ArrayRef<std::optional<Tensor>>;

-namespace {
-
-bool check_fast_path_conditions(
-    ET_UNUSED const Tensor& in,
-    TensorOptList indices,
-    size_t* dim) {
-  bool found_index = false;
-  for (const auto i : c10::irange(indices.size())) {
-    if (indices[i].has_value()) {
-      *dim = i;
-      // Fast path only supports a single non-null index tensor
-      if (found_index) {
-        return false;
-      }
-      found_index = true;
-      const Tensor& index = indices[i].value();
-      ScalarType ix_type = index.scalar_type();
-      // Fast path only supports Long or Int index tensors
-      if (ix_type != ScalarType::Long && ix_type != ScalarType::Int) {
-        return false;
-      }
-      // Fast path only supports a 1-dimensional index tensor
-      if (index.dim() != 1) {
-        return false;
-      }
-    }
-  }
-
-  // Fast path needs at least one non-null index tensor
-  if (!found_index) {
-    return false;
-  }
-
-  return true;
-}
-
-bool check_fast_path_args(
-    const Tensor& in,
-    TensorOptList indices,
-    size_t dim,
-    Tensor& out) {
-  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
-
-  ET_CHECK_OR_RETURN_FALSE(
-      static_cast<ssize_t>(indices.size()) <= in.dim(),
-      "Indexing too many dimensions");
-
-  const Tensor& index = indices[dim].value();
-
-  bool is_valid_index = true;
-  ET_SWITCH_TWO_TYPES(
-      Long, Int, index.scalar_type(), ctx, "index_put_", CTYPE, [&]() {
-        const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
-        for (const auto i : c10::irange(index.numel())) {
-          if (index_arr[i] < 0 ||
-              index_arr[i] >= static_cast<CTYPE>(in.size(dim))) {
-            ET_LOG(
-                Error,
-                "Index %" PRId64
-                " out of range for tensor with size %zd"
-                " at dimension %zu",
-                static_cast<int64_t>(index_arr[i]),
-                in.size(dim),
-                dim);
-            is_valid_index = false;
-            break;
-          }
-        }
-      });
-
-  ET_CHECK_OR_RETURN_FALSE(
-      is_valid_index,
-      "Some index values are not within bounds of input tensor at indexed dim");
-
-  return true;
-}
-
-Tensor& fast_path(
+Tensor& index_Tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& in,
     TensorOptList indices,
-    size_t dim,
     Tensor& out) {
   (void)ctx;

   ET_KERNEL_CHECK(
-      ctx, check_fast_path_args(in, indices, dim, out), InvalidArgument, out);
-
-  const Tensor& index = indices[dim].value();
-  ScalarType index_type = index.scalar_type();
-
-  if (out.dim() == 0) {
-    memcpy(out.mutable_data_ptr(), in.const_data_ptr(), out.nbytes());
-    return out;
-  }
-
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  if (leading_dims == 0 || trailing_dims == 0) {
-    return out;
-  }
-
-  size_t in_dim_length = in.size(dim);
-  size_t out_dim_length = out.size(dim);
-
-  size_t length_per_step = trailing_dims * in.element_size();
-
-  const char* in_data = in.const_data_ptr<char>();
-  char* out_data = out.mutable_data_ptr<char>();
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "index.Tensor_out";
-
-  ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
-    const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
-    for (const auto i : c10::irange(leading_dims)) {
-      const char* src = in_data + i * in_dim_length * length_per_step;
-      char* dest = out_data + i * out_dim_length * length_per_step;
-      for (const auto j : c10::irange(out_dim_length)) {
-        const char* copy_src = src + index_arr[j] * length_per_step;
-        char* copy_dest = dest + j * length_per_step;
-        memcpy(copy_dest, copy_src, length_per_step);
-      }
-    }
-  });
-
-  return out;
-}
-
-} // namespace
-
-Tensor& index_Tensor_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& in,
-    TensorOptList indices,
-    Tensor& out) {
-  (void)ctx;
+      ctx, check_index_args(in, indices, out), InvalidArgument, out);

   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

-  size_t dim = 0;
-  bool is_fast_path = check_fast_path_conditions(in, indices, &dim);
-  if (is_fast_path) {
-    return fast_path(ctx, in, indices, dim, out);
-  }
-
-  ET_KERNEL_CHECK(
-      ctx, check_index_args(in, indices, out), InvalidArgument, out);
-
   ScalarType in_type = in.scalar_type();
   size_t block_count = count_index_blocks(indices);

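For context, the `fast_path` deleted above performs a gather along a single indexed dimension: for each leading slice it copies one contiguous block of `trailing_dims` elements per output index with `memcpy`. Below is a minimal standalone sketch of that copy pattern in plain C++. It uses raw arrays instead of ExecuTorch tensors; the names (`gather_dim`), the `float` element type, and the hard-coded shapes are illustrative assumptions, not part of the ExecuTorch API.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Gather along the middle dimension of a contiguous tensor, mirroring the
// copy loop of the removed fast path. Shapes (illustrative):
//   in  : [leading_dims, in_dim_length,  trailing_dims]
//   out : [leading_dims, out_dim_length, trailing_dims]
static void gather_dim(
    const float* in,
    float* out,
    const int64_t* index, // out_dim_length entries, each in [0, in_dim_length)
    size_t leading_dims,
    size_t in_dim_length,
    size_t out_dim_length,
    size_t trailing_dims) {
  const size_t length_per_step = trailing_dims * sizeof(float);
  const char* in_data = reinterpret_cast<const char*>(in);
  char* out_data = reinterpret_cast<char*>(out);
  for (size_t i = 0; i < leading_dims; ++i) {
    const char* src = in_data + i * in_dim_length * length_per_step;
    char* dest = out_data + i * out_dim_length * length_per_step;
    for (size_t j = 0; j < out_dim_length; ++j) {
      // Copy the trailing block selected by index[j] into output slot j.
      std::memcpy(
          dest + j * length_per_step,
          src + index[j] * length_per_step,
          length_per_step);
    }
  }
}

int main() {
  // in: shape [2, 3, 2]; gather rows {2, 0} along dim 1 -> out: shape [2, 2, 2]
  const float in[2 * 3 * 2] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  const int64_t index[2] = {2, 0};
  float out[2 * 2 * 2] = {};
  gather_dim(in, out, index, 2, 3, 2, 2);
  for (float v : out) {
    std::printf("%g ", v); // expected: 4 5 0 1 10 11 6 7
  }
  std::printf("\n");
  return 0;
}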