|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
18 | | -#include "vec/functions/array/function_array_cross_product.h" |
| 18 | +#include <gen_cpp/Types_types.h> |
19 | 19 |
|
| 20 | +#include "common/exception.h" |
| 21 | +#include "common/status.h" |
| 22 | +#include "runtime/primitive_type.h" |
| 23 | +#include "vec/columns/column.h" |
| 24 | +#include "vec/columns/column_array.h" |
| 25 | +#include "vec/columns/column_nullable.h" |
| 26 | +#include "vec/common/assert_cast.h" |
| 27 | +#include "vec/core/types.h" |
| 28 | +#include "vec/data_types/data_type.h" |
| 29 | +#include "vec/data_types/data_type_array.h" |
| 30 | +#include "vec/data_types/data_type_nullable.h" |
| 31 | +#include "vec/data_types/data_type_number.h" |
| 32 | +#include "vec/functions/array/function_array_utils.h" |
| 33 | +#include "vec/functions/function.h" |
20 | 34 | #include "vec/functions/simple_function_factory.h" |
| 35 | +#include "vec/utils/util.hpp" |
21 | 36 |
|
22 | 37 | namespace doris::vectorized { |
23 | 38 |
|
| 39 | +class FunctionArrayCrossProduct : public IFunction { |
| 40 | +public: |
| 41 | + using DataType = PrimitiveTypeTraits<TYPE_DOUBLE>::DataType; |
| 42 | + using ColumnType = PrimitiveTypeTraits<TYPE_DOUBLE>::ColumnType; |
| 43 | + |
| 44 | + static constexpr auto name = "cross_product"; |
| 45 | + String get_name() const override { return name; } |
| 46 | + static FunctionPtr create() { return std::make_shared<FunctionArrayCrossProduct>(); } |
| 47 | + size_t get_number_of_arguments() const override { return 2; } |
| 48 | + |
| 49 | + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { |
| 50 | + if (arguments.size() != 2) { |
| 51 | + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
| 52 | + "Invalid number of arguments for function {}", get_name()); |
| 53 | + } |
| 54 | + |
| 55 | + if (arguments[0]->get_primitive_type() != TYPE_ARRAY || |
| 56 | + arguments[1]->get_primitive_type() != TYPE_ARRAY) { |
| 57 | + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
| 58 | + "Arguments for function {} must be arrays", get_name()); |
| 59 | + } |
| 60 | + |
| 61 | + // return ARRAY<DOUBLE> |
| 62 | + return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>()); |
| 63 | + } |
| 64 | + |
| 65 | + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, |
| 66 | + uint32_t result, size_t input_rows_count) const override { |
| 67 | + const auto& arg1 = block.get_by_position(arguments[0]); |
| 68 | + const auto& arg2 = block.get_by_position(arguments[1]); |
| 69 | + |
| 70 | + const IColumn* col1 = arg1.column.get(); |
| 71 | + const IColumn* col2 = arg2.column.get(); |
| 72 | + |
| 73 | + const ColumnConst* col1_const = nullptr; |
| 74 | + const ColumnConst* col2_const = nullptr; |
| 75 | + |
| 76 | + if (is_column_const(*col1)) { |
| 77 | + col1_const = assert_cast<const ColumnConst*>(col1); |
| 78 | + col1 = &col1_const->get_data_column(); |
| 79 | + } |
| 80 | + if (is_column_const(*col2)) { |
| 81 | + col2_const = assert_cast<const ColumnConst*>(col2); |
| 82 | + col2 = &col2_const->get_data_column(); |
| 83 | + } |
| 84 | + |
| 85 | + const ColumnArray* arr1 = nullptr; |
| 86 | + const ColumnArray* arr2 = nullptr; |
| 87 | + |
| 88 | + if (col1->is_nullable()) { |
| 89 | + auto nullable1 = assert_cast<const ColumnNullable*>(col1); |
| 90 | + arr1 = assert_cast<const ColumnArray*>(nullable1->get_nested_column_ptr().get()); |
| 91 | + } else { |
| 92 | + arr1 = assert_cast<const ColumnArray*>(col1); |
| 93 | + } |
| 94 | + if (col2->is_nullable()) { |
| 95 | + auto nullable2 = assert_cast<const ColumnNullable*>(col2); |
| 96 | + arr2 = assert_cast<const ColumnArray*>(nullable2->get_nested_column_ptr().get()); |
| 97 | + } else { |
| 98 | + arr2 = assert_cast<const ColumnArray*>(col2); |
| 99 | + } |
| 100 | + |
| 101 | + const ColumnFloat64* float1 = nullptr; |
| 102 | + const ColumnFloat64* float2 = nullptr; |
| 103 | + |
| 104 | + if (arr1->get_data_ptr()->is_nullable()) { |
| 105 | + if (arr1->get_data_ptr()->has_null()) { |
| 106 | + return Status::InvalidArgument("First argument for function {} cannot have null elements", |
| 107 | + get_name()); |
| 108 | + } |
| 109 | + auto nullable1 = assert_cast<const ColumnNullable*>(arr1->get_data_ptr().get()); |
| 110 | + float1 = assert_cast<const ColumnFloat64*>(nullable1->get_nested_column_ptr().get()); |
| 111 | + } else { |
| 112 | + float1 = assert_cast<const ColumnFloat64*>(arr1->get_data_ptr().get()); |
| 113 | + } |
| 114 | + |
| 115 | + if (arr2->get_data_ptr()->is_nullable()) { |
| 116 | + if (arr2->get_data_ptr()->has_null()) { |
| 117 | + return Status::InvalidArgument( |
| 118 | + "Second argument for function {} cannot have null elements", |
| 119 | + get_name()); |
| 120 | + } |
| 121 | + auto nullable2 = assert_cast<const ColumnNullable*>(arr2->get_data_ptr().get()); |
| 122 | + float2 = assert_cast<const ColumnFloat64*>(nullable2->get_nested_column_ptr().get()); |
| 123 | + } else { |
| 124 | + float2 = assert_cast<const ColumnFloat64*>(arr2->get_data_ptr().get()); |
| 125 | + } |
| 126 | + |
| 127 | + const auto* offset1 = |
| 128 | + assert_cast<const ColumnArray::ColumnOffsets*>(arr1->get_offsets_ptr().get()); |
| 129 | + const auto* offset2 = |
| 130 | + assert_cast<const ColumnArray::ColumnOffsets*>(arr2->get_offsets_ptr().get()); |
| 131 | + |
| 132 | + // prepare result data |
| 133 | + auto nested_res = ColumnFloat64::create(); |
| 134 | + auto& nested_data = nested_res->get_data(); |
| 135 | + nested_data.resize(3 * input_rows_count); |
| 136 | + |
| 137 | + auto offsets_res = ColumnArray::ColumnOffsets::create(); |
| 138 | + auto& offsets_data = offsets_res->get_data(); |
| 139 | + offsets_data.resize(input_rows_count); |
| 140 | + |
| 141 | + size_t current_offset = 0; |
| 142 | + for (ssize_t row = 0; row < input_rows_count; ++row) { |
| 143 | + ssize_t row1 = col1_const ? 0 : row; |
| 144 | + ssize_t row2 = col2_const ? 0 : row; |
| 145 | + |
| 146 | + ssize_t prev_offset1 = (row1 == 0) ? 0 : offset1->get_data()[row1 - 1]; |
| 147 | + ssize_t prev_offset2 = (row2 == 0) ? 0 : offset2->get_data()[row2 - 1]; |
| 148 | + |
| 149 | + ssize_t size1 = offset1->get_data()[row] - prev_offset1; |
| 150 | + ssize_t size2 = offset2->get_data()[row] - prev_offset2; |
| 151 | + |
| 152 | + if (size1 == 0 || size2 == 0) { |
| 153 | + nested_data[current_offset] = 0; |
| 154 | + nested_data[current_offset + 1] = 0; |
| 155 | + nested_data[current_offset + 2] = 0; |
| 156 | + |
| 157 | + current_offset += 3; |
| 158 | + offsets_data[row] = current_offset; |
| 159 | + continue; |
| 160 | + } |
| 161 | + |
| 162 | + if (size1 != 3 || size2 != 3) { |
| 163 | + return Status::InvalidArgument( |
| 164 | + "function {} requires arrays of size 3", get_name()); |
| 165 | + } |
| 166 | + |
| 167 | + ssize_t base1 = prev_offset1; |
| 168 | + ssize_t base2 = prev_offset2; |
| 169 | + |
| 170 | + double a1 = float1->get_data()[base1]; |
| 171 | + double a2 = float1->get_data()[base1 + 1]; |
| 172 | + double a3 = float1->get_data()[base1 + 2]; |
| 173 | + |
| 174 | + double b1 = float2->get_data()[base2]; |
| 175 | + double b2 = float2->get_data()[base2 + 1]; |
| 176 | + double b3 = float2->get_data()[base2 + 2]; |
| 177 | + |
| 178 | + nested_data[current_offset] = a2 * b3 - a3 * b2; |
| 179 | + nested_data[current_offset + 1] = a3 * b1 - a1 * b3; |
| 180 | + nested_data[current_offset + 2] = a1 * b2 - a2 * b1; |
| 181 | + |
| 182 | + current_offset += 3; |
| 183 | + offsets_data[row] = current_offset; |
| 184 | + } |
| 185 | + |
| 186 | + auto result_col = ColumnArray::create( |
| 187 | + ColumnNullable::create(std::move(nested_res), |
| 188 | + ColumnUInt8::create(nested_res->size(), 0)), |
| 189 | + std::move(offsets_res)); |
| 190 | + |
| 191 | + block.replace_by_position(result, std::move(result_col)); |
| 192 | + return Status::OK(); |
| 193 | + } |
| 194 | +}; |
| 195 | + |
24 | 196 | void register_function_array_cross_product(SimpleFunctionFactory& factory) { |
25 | 197 | factory.register_function<FunctionArrayCrossProduct>(); |
26 | 198 | } |
|
0 commit comments