|
| 1 | +// Copyright (C) 2018-2025 Intel Corporation |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// |
| 4 | + |
| 5 | +#pragma once |
| 6 | + |
| 7 | +#include <sstream> |
| 8 | +#include <unordered_set> |
| 9 | + |
| 10 | +#include "openvino/op/sparse_fill_empty_rows.hpp" |
| 11 | +#include "utils.hpp" |
| 12 | + |
| 13 | +namespace ov::op::v16 { |
| 14 | +template <class TShape, class TRShape = result_shape_t<TShape>> |
| 15 | +std::vector<TRShape> shape_infer(const SparseFillEmptyRows* op, |
| 16 | + const std::vector<TShape>& input_shapes, |
| 17 | + const ITensorAccessor& tensor_accessor = make_tensor_accessor()) { |
| 18 | + NODE_VALIDATION_CHECK(op, input_shapes.size() == 4); |
| 19 | + |
| 20 | + const auto& values_shape = input_shapes[0]; |
| 21 | + NODE_SHAPE_INFER_CHECK(op, |
| 22 | + input_shapes, |
| 23 | + values_shape.rank().compatible(1), |
| 24 | + "The values input must be a 1D tensor.", |
| 25 | + values_shape); |
| 26 | + |
| 27 | + const auto& dense_shape = input_shapes[1]; |
| 28 | + const bool is_dense_shape_rank_dynamic = dense_shape.rank().is_dynamic(); |
| 29 | + const bool is_dense_shape_valid = |
| 30 | + is_dense_shape_rank_dynamic || (dense_shape.size() == 1 && dense_shape[0].compatible(2)); |
| 31 | + NODE_SHAPE_INFER_CHECK( |
| 32 | + op, |
| 33 | + input_shapes, |
| 34 | + is_dense_shape_valid, |
| 35 | + "The dense_shape input must be 1D and have exactly 2 elements, meaning only 2D sparse tensors are supported."); |
| 36 | + |
| 37 | + const auto& indices_shape = input_shapes[2]; |
| 38 | + const bool is_indices_shape_valid = indices_shape.rank().is_dynamic() || |
| 39 | + (indices_shape.size() == 2 && indices_shape[1].compatible(2) && |
| 40 | + (is_dense_shape_rank_dynamic || indices_shape[0].compatible(values_shape[0]))); |
| 41 | + NODE_SHAPE_INFER_CHECK(op, |
| 42 | + input_shapes, |
| 43 | + is_indices_shape_valid, |
| 44 | + "The indices input must be a 2D tensor with the first dimension matching the size of values " |
| 45 | + "and the second dimension having 2 elements.", |
| 46 | + indices_shape); |
| 47 | + |
| 48 | + const auto& default_value_shape = input_shapes[3]; |
| 49 | + NODE_SHAPE_INFER_CHECK(op, |
| 50 | + input_shapes, |
| 51 | + default_value_shape.rank().compatible(0), |
| 52 | + "The default_value input must be a scalar.", |
| 53 | + default_value_shape); |
| 54 | + |
| 55 | + auto output_shapes = std::vector<TRShape>(3); |
| 56 | + auto& output_indices_shape = output_shapes[0]; |
| 57 | + auto& output_values_shape = output_shapes[1]; |
| 58 | + auto& empty_row_indicator_shape = output_shapes[2]; |
| 59 | + output_indices_shape.resize(2); |
| 60 | + output_values_shape.resize(1); |
| 61 | + empty_row_indicator_shape.resize(1); |
| 62 | + output_indices_shape[1] = 2; // Only 2D cases are supported |
| 63 | + |
| 64 | + if (auto dense_shape_value = get_input_const_data_as_shape<TRShape>(op, 1, tensor_accessor)) { |
| 65 | + const auto& number_of_rows = (*dense_shape_value)[0].get_length(); |
| 66 | + empty_row_indicator_shape[0] = number_of_rows; |
| 67 | + |
| 68 | + if (auto indices_value = get_input_const_data_as<TRShape, int64_t>(op, 2, tensor_accessor)) { |
| 69 | + auto is_valid_index = [](int64_t index, int64_t max_value) -> bool { |
| 70 | + return index >= 0 && index < max_value; |
| 71 | + }; |
| 72 | + auto create_index_error_message = |
| 73 | + [](const std::string& dim_name, int64_t index, int64_t max_value) -> std::string { |
| 74 | + std::stringstream ss; |
| 75 | + ss << "Sparse tensor index out of bounds: " << dim_name << " " << index |
| 76 | + << " is outside the valid range [0, " << (max_value - 1) << "]"; |
| 77 | + return ss.str(); |
| 78 | + }; |
| 79 | + |
| 80 | + // Rows can be referenced multiple times in sparse representation |
| 81 | + std::unordered_set<int64_t> existing_rows; |
| 82 | + const auto& indices_data = *indices_value; |
| 83 | + const auto& number_of_cols = (*dense_shape_value)[1].get_length(); |
| 84 | + for (size_t i = 0, i_next = 1; i_next < indices_data.size(); i += 2, i_next += 2) { |
| 85 | + auto row = indices_data[i]; |
| 86 | + NODE_SHAPE_INFER_CHECK(op, |
| 87 | + input_shapes, |
| 88 | + is_valid_index(row, number_of_rows), |
| 89 | + create_index_error_message("row", row, number_of_rows)); |
| 90 | + |
| 91 | + auto col = indices_data[i_next]; |
| 92 | + NODE_SHAPE_INFER_CHECK(op, |
| 93 | + input_shapes, |
| 94 | + is_valid_index(col, number_of_cols), |
| 95 | + create_index_error_message("column", col, number_of_cols)); |
| 96 | + |
| 97 | + existing_rows.insert(row); |
| 98 | + } |
| 99 | + |
| 100 | + using TDim = typename TRShape::value_type; |
| 101 | + TDim empty_rows_count = number_of_rows - existing_rows.size(); |
| 102 | + output_indices_shape[0] = indices_shape[0] + empty_rows_count; |
| 103 | + output_values_shape[0] = values_shape[0] + empty_rows_count; |
| 104 | + } else { |
| 105 | + output_indices_shape[0] = Dimension::dynamic(); |
| 106 | + output_values_shape[0] = Dimension::dynamic(); |
| 107 | + } |
| 108 | + } else { |
| 109 | + empty_row_indicator_shape[0] = Dimension::dynamic(); |
| 110 | + } |
| 111 | + |
| 112 | + return output_shapes; |
| 113 | +} |
| 114 | +} // namespace ov::op::v16 |
0 commit comments