
 #include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_shape_to_c_string.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

 namespace torch {
 namespace executor {
 namespace native {

 using Tensor = executorch::aten::Tensor;
+using TensorOptList =
+    executorch::aten::ArrayRef<executorch::aten::optional<Tensor>>;

 Tensor& index_put_out(
     KernelRuntimeContext& ctx,
     const Tensor& in,
-    executorch::aten::ArrayRef<executorch::aten::optional<Tensor>> indices,
+    TensorOptList indices,
     const Tensor& values,
     const bool accumulate,
     Tensor& out) {
@@ -154,6 +157,177 @@ Tensor& index_put_out(
   return out;
 }

+namespace {
+
+bool check_special_case_in_place_args(
+    Tensor& in,
+    TensorOptList indices,
+    const Tensor& values,
+    const bool accumulate,
+    size_t* dim) {
+  ET_CHECK_OR_RETURN_FALSE(
+      !accumulate,
+      "Special case in-place index_put does not support accumulate");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      static_cast<ssize_t>(indices.size()) <= in.dim(),
+      "Indexing too many dimensions");
+
+  bool found_index = false;
+  for (const auto i : c10::irange(indices.size())) {
+    if (indices[i].has_value()) {
+      *dim = i;
+      ET_CHECK_OR_RETURN_FALSE(
+          !found_index,
+          "Special case in-place index_put only supports a single non-null index tensor");
+      found_index = true;
+      const Tensor& index = indices[i].value();
+      ScalarType ix_type = index.scalar_type();
+      ET_CHECK_OR_RETURN_FALSE(
+          ix_type == ScalarType::Long || ix_type == ScalarType::Int,
+          "Special case in-place index_put only supports Long or Int index tensors; got %d",
+          static_cast<int>(ix_type));
+      ET_CHECK_OR_RETURN_FALSE(
+          index.dim() == 1,
+          "Special case in-place index_put only supports 1-dimensional index tensors; got %d",
+          static_cast<int>(index.dim()));
+    }
+  }
+
+  ET_CHECK_OR_RETURN_FALSE(
+      found_index,
+      "Special case in-place index_put needs at least one non-null index tensor");
+
+  const Tensor& index = indices[*dim].value();
+
+  bool is_valid_index = true;
+  ET_SWITCH_TWO_TYPES(
+      Long, Int, index.scalar_type(), ctx, "index_put_", CTYPE, [&]() {
+        const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+        for (const auto i : c10::irange(index.numel())) {
+          if (index_arr[i] < 0 ||
+              index_arr[i] >= static_cast<CTYPE>(in.size(*dim))) {
+            ET_LOG(
+                Error,
+                "Index %" PRId64
+                " out of range for tensor with size %zd"
+                " at dimension %zu",
+                static_cast<int64_t>(index_arr[i]),
+                in.size(*dim),
+                *dim);
+            is_valid_index = false;
+            break;
+          }
+        }
+      });
+
+  ET_CHECK_OR_RETURN_FALSE(
+      is_valid_index,
+      "Some index values are not within bounds of input tensor at indexed dim");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      values.size(*dim) == index.size(0),
+      "Special case in-place index_put requires values to match index length at the indexed dim; values.size(%zu) = %" ET_PRI_TENSOR_SIZE
+      ", index_length = %zd",
+      *dim,
+      values.size(*dim),
+      index.size(0));
+
+  Tensor::SizesType expected_values_size[kTensorDimensionLimit] = {};
+  size_t in_ndim = static_cast<size_t>(in.dim());
+  for (const auto i : c10::irange(in_ndim)) {
+    if (i != *dim) {
+      expected_values_size[i] = static_cast<Tensor::SizesType>(in.size(i));
+    }
+  }
+  expected_values_size[*dim] = static_cast<Tensor::SizesType>(index.size(0));
+
+#if ET_LOG_ENABLED
+  auto in_shape_str = executorch::runtime::tensor_shape_to_c_string(
+      executorch::runtime::Span<const Tensor::SizesType>(
+          in.sizes().data(), in.sizes().size()));
+  auto values_shape_str = executorch::runtime::tensor_shape_to_c_string(
+      executorch::runtime::Span<const Tensor::SizesType>(
+          values.sizes().data(), values.sizes().size()));
+
+  ET_CHECK_OR_RETURN_FALSE(
+      tensor_has_expected_size(values, {expected_values_size, in_ndim}),
+      "Special case in-place index_put requires values to match input shape except for indexed dim; got input shape %s and values shape %s",
+      in_shape_str.data(),
+      values_shape_str.data());
+#else
+  ET_CHECK_OR_RETURN_FALSE(
+      tensor_has_expected_size(values, {expected_values_size, in_ndim}),
+      "Special case in-place index_put requires values to match input shape except for indexed dim");
+#endif // ET_LOG_ENABLED
+
+  return true;
+}
+
+} // namespace
+
+Tensor& index_put_(
+    KernelRuntimeContext& ctx,
+    Tensor& in,
+    TensorOptList indices,
+    const Tensor& values,
+    const bool accumulate) {
+  (void)ctx;
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dtype(in, values), InvalidArgument, in);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, values), InvalidArgument, in);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, in);
+
+  size_t dim = 0;
+  ET_KERNEL_CHECK(
+      ctx,
+      check_special_case_in_place_args(in, indices, values, accumulate, &dim),
+      InvalidArgument,
+      in);
+
+  const Tensor& index = indices[dim].value();
+  ScalarType index_type = index.scalar_type();
+
+  if (in.dim() == 0) {
+    memcpy(in.mutable_data_ptr(), values.const_data_ptr(), in.nbytes());
+    return in;
+  }
+
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_dims = getTrailingDims(in, dim);
+
+  if (leading_dims == 0 || trailing_dims == 0) {
+    return in;
+  }
+
+  size_t values_dim_length = values.size(dim);
+  size_t in_dim_length = in.size(dim);
+
+  size_t length_per_step = trailing_dims * in.element_size();
+
+  const char* values_data = values.const_data_ptr<char>();
+  char* in_data = in.mutable_data_ptr<char>();
+
+  ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, "index_put_", CTYPE, [&]() {
+    const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+    for (const auto i : c10::irange(leading_dims)) {
+      const char* src = values_data + i * values_dim_length * length_per_step;
+      char* dest = in_data + i * in_dim_length * length_per_step;
+      for (const auto j : c10::irange(values_dim_length)) {
+        const char* copy_src = src + j * length_per_step;
+        char* copy_dest = dest + index_arr[j] * length_per_step;
+        memcpy(copy_dest, copy_src, length_per_step);
+      }
+    }
+  });
+
+  return in;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
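
Review note (not part of the patch): the core of the new `index_put_` kernel is the flat-memory copy loop at the bottom, which treats a contiguous input of logical shape `[leading_dims, in.size(dim), trailing_dims]` as `leading_dims` blocks and, inside each block, overwrites the `trailing_dims`-sized slab selected by each index entry. The standalone sketch below reproduces that arithmetic with plain standard C++ containers; the shapes, names, and data are illustrative only and involve no ExecuTorch types.

```cpp
// Standalone illustration of the copy pattern used by the special-case
// index_put_: overwrite slabs of `in` (logical shape [leading, in_dim, trailing])
// selected by a 1-D `index` with the matching slabs of `values`
// (logical shape [leading, index_len, trailing]). Plain std::vector stands in
// for ExecuTorch tensors; this is not ExecuTorch code.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const size_t leading = 2, in_dim = 4, trailing = 3;  // in: [2, 4, 3]
  const std::vector<long> index = {1, 3};              // 1-D index tensor
  const size_t index_len = index.size();               // values: [2, 2, 3]

  std::vector<float> in(leading * in_dim * trailing, 0.0f);
  std::vector<float> values(leading * index_len * trailing);
  for (size_t i = 0; i < values.size(); ++i) {
    values[i] = static_cast<float>(i + 1);
  }

  // Same loop structure as the kernel: one trailing-sized copy per index entry,
  // repeated for every leading block.
  for (size_t i = 0; i < leading; ++i) {
    const float* src = values.data() + i * index_len * trailing;
    float* dest = in.data() + i * in_dim * trailing;
    for (size_t j = 0; j < index_len; ++j) {
      std::copy_n(src + j * trailing, trailing, dest + index[j] * trailing);
    }
  }

  // Print `in` as rows of `trailing` elements: only rows 1 and 3 of each
  // leading block have been overwritten, the rest stay zero.
  for (size_t i = 0; i < in.size(); ++i) {
    std::printf("%g%c", in[i], (i + 1) % trailing == 0 ? '\n' : ' ');
  }
  return 0;
}
```

With `index = {1, 3}` along dim 1, this matches the semantics of `in[:, index, :] = values` with `accumulate = false`, which is exactly the case the new special-case path handles.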
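For completeness, here is a hedged sketch of how the new in-place variant might be exercised from a small driver. It assumes the test-only `torch::executor::testing::TensorFactory` helper is available and forward-declares `torch::executor::native::index_put_` to mirror the signature introduced in this diff; the include paths, the driver itself, and the data are assumptions for illustration, not part of the patch or its test suite.

```cpp
// Illustrative driver only; includes and the forward declaration are
// assumptions based on this patch, not part of it.
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

using executorch::aten::ArrayRef;
using executorch::aten::optional;
using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using torch::executor::testing::TensorFactory;

namespace torch {
namespace executor {
namespace native {
// Mirrors the signature added in this diff.
Tensor& index_put_(
    KernelRuntimeContext& ctx,
    Tensor& in,
    ArrayRef<optional<Tensor>> indices,
    const Tensor& values,
    bool accumulate);
} // namespace native
} // namespace executor
} // namespace torch

int main() {
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_long;

  // in: [2, 4] zeros; values: [2, 2]; index picks columns 1 and 3 (dim = 1).
  Tensor in = tf.zeros({2, 4});
  Tensor values = tf.make({2, 2}, {10, 20, 30, 40});
  Tensor index = tf_long.make({2}, {1, 3});

  // Exactly one non-null, 1-D index; its position selects the indexed dim,
  // which is what check_special_case_in_place_args requires.
  optional<Tensor> indices_storage[2] = {
      optional<Tensor>(), optional<Tensor>(index)};
  ArrayRef<optional<Tensor>> indices(indices_storage, 2);

  torch::executor::KernelRuntimeContext context;
  // accumulate must be false to stay on the special-case path.
  torch::executor::native::index_put_(
      context, in, indices, values, /*accumulate=*/false);

  // in is now {{0, 10, 0, 20}, {0, 30, 0, 40}}.
  return 0;
}
```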