1+ // Licensed to the Apache Software Foundation (ASF) under one
2+ // or more contributor license agreements. See the NOTICE file
3+ // distributed with this work for additional information
4+ // regarding copyright ownership. The ASF licenses this file
5+ // to you under the Apache License, Version 2.0 (the
6+ // "License"); you may not use this file except in compliance
7+ // with the License. You may obtain a copy of the License at
8+ //
9+ // http://www.apache.org/licenses/LICENSE-2.0
10+ //
11+ // Unless required by applicable law or agreed to in writing,
12+ // software distributed under the License is distributed on an
13+ // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+ // KIND, either express or implied. See the License for the
15+ // specific language governing permissions and limitations
16+ // under the License.
17+
18+ #pragma once
19+
20+ #include < gen_cpp/Types_types.h>
21+
22+ #include " common/exception.h"
23+ #include " common/status.h"
24+ #include " runtime/primitive_type.h"
25+ #include " vec/columns/column.h"
26+ #include " vec/columns/column_array.h"
27+ #include " vec/columns/column_nullable.h"
28+ #include " vec/common/assert_cast.h"
29+ #include " vec/core/types.h"
30+ #include " vec/data_types/data_type.h"
31+ #include " vec/data_types/data_type_array.h"
32+ #include " vec/data_types/data_type_nullable.h"
33+ #include " vec/data_types/data_type_number.h"
34+ #include " vec/functions/array/function_array_utils.h"
35+ #include " vec/functions/function.h"
36+ #include " vec/utils/util.hpp"
37+
38+ namespace doris ::vectorized {
39+
40+ class FunctionArrayCrossProduct : public IFunction {
41+ public:
42+ using DataType = PrimitiveTypeTraits<TYPE_FLOAT>::DataType;
43+ using ColumnType = PrimitiveTypeTraits<TYPE_FLOAT>::ColumnType;
44+
45+ static constexpr auto name = " cross_product" ;
46+ String get_name () const override { return name; }
47+ static FunctionPtr create () { return std::make_shared<FunctionArrayCrossProduct>(); }
48+ size_t get_number_of_arguments () const override { return 2 ; }
49+
50+ DataTypePtr get_return_type_impl (const DataTypes& arguments) const override {
51+ if (arguments.size () != 2 ) {
52+ throw doris::Exception (ErrorCode::INVALID_ARGUMENT,
53+ " Invalid number of arguments for function {}" , get_name ());
54+ }
55+
56+ if (arguments[0 ]->get_primitive_type () != TYPE_ARRAY ||
57+ arguments[1 ]->get_primitive_type () != TYPE_ARRAY) {
58+ throw doris::Exception (ErrorCode::INVALID_ARGUMENT,
59+ " Arguments for function {} must be arrays" , get_name ());
60+ }
61+
62+ // return ARRAY<FLOAT>
63+ return std::make_shared<DataTypeArray>(
64+ std::make_shared<DataTypeFloat32>());
65+ }
66+
67+ // strict semantics: do not allow NULL
68+ bool use_default_implementation_for_nulls () const override { return false ; }
69+
70+ Status execute_impl (FunctionContext* context, Block& block, const ColumnNumbers& arguments,
71+ uint32_t result, size_t input_rows_count) const override {
72+ const auto & arg1 = block.get_by_position (arguments[0 ]);
73+ const auto & arg2 = block.get_by_position (arguments[1 ]);
74+
75+ auto col1 = arg1.column ->convert_to_full_column_if_const ();
76+ auto col2 = arg2.column ->convert_to_full_column_if_const ();
77+
78+ if (col1->size () != col2->size ()) {
79+ return Status::RuntimeError (
80+ fmt::format (" function {} have different input array sizes: {} and {}" ,
81+ get_name (), col1->size (), col2->size ()));
82+ }
83+
84+ const ColumnArray* arr1 = nullptr ;
85+ const ColumnArray* arr2 = nullptr ;
86+
87+ if (const auto * nullable =
88+ typeid_cast<const ColumnNullable*>(col1.get ())) {
89+ if (nullable->has_null ()) {
90+ throw doris::Exception (ErrorCode::INVALID_ARGUMENT,
91+ " First argument for function {} cannot be null" , get_name ());
92+ }
93+ arr1 = assert_cast<const ColumnArray*>(nullable->get_nested_column_ptr ().get ());
94+ } else {
95+ arr1 = assert_cast<const ColumnArray*>(col1.get ());
96+ }
97+
98+ if (const auto * nullable =
99+ typeid_cast<const ColumnNullable*>(col2.get ())) {
100+ if (nullable->has_null ()) {
101+ throw doris::Exception (ErrorCode::INVALID_ARGUMENT,
102+ " Second argument for function {} cannot be null" ,
103+ get_name ());
104+ }
105+ arr2 = assert_cast<const ColumnArray*>(nullable->get_nested_column_ptr ().get ());
106+ } else {
107+ arr2 = assert_cast<const ColumnArray*>(col2.get ());
108+ }
109+
110+ const ColumnFloat32* float1 = nullptr ;
111+ const ColumnFloat32* float2 = nullptr ;
112+
113+ if (const auto * nullable =
114+ typeid_cast<const ColumnNullable*>(arr1->get_data_ptr ().get ())) {
115+ if (nullable->has_null ()) {
116+ throw doris::Exception (ErrorCode::INVALID_ARGUMENT,
117+ " First argument for function {} cannot have null elements" ,
118+ get_name ());
119+ }
120+ float1 = assert_cast<const ColumnFloat32*>(nullable->get_nested_column_ptr ().get ());
121+ } else {
122+ float1 = assert_cast<const ColumnFloat32*>(arr1->get_data_ptr ().get ());
123+ }
124+
125+ if (const auto * nullable =
126+ typeid_cast<const ColumnNullable*>(arr2->get_data_ptr ().get ())) {
127+ if (nullable->has_null ()) {
128+ throw doris::Exception (ErrorCode::INVALID_ARGUMENT,
129+ " Second argument for function {} cannot have null elements" ,
130+ get_name ());
131+ }
132+ float2 = assert_cast<const ColumnFloat32*>(nullable->get_nested_column_ptr ().get ());
133+ } else {
134+ float2 = assert_cast<const ColumnFloat32*>(arr2->get_data_ptr ().get ());
135+ }
136+
137+ const auto * offset1 =
138+ assert_cast<const ColumnArray::ColumnOffsets*>(arr1->get_offsets_ptr ().get ());
139+ const auto * offset2 =
140+ assert_cast<const ColumnArray::ColumnOffsets*>(arr2->get_offsets_ptr ().get ());
141+
142+ // prepare result data
143+ auto nested_res = ColumnFloat32::create ();
144+ auto offsets_res = ColumnArray::ColumnOffsets::create ();
145+ auto & offsets_data = offsets_res->get_data ();
146+ offsets_data.reserve (input_rows_count);
147+ size_t current_offset = 0 ;
148+
149+ size_t row_cnt = offset1->size ();
150+ size_t prev_offset1 = 0 ;
151+ size_t prev_offset2 = 0 ;
152+ for (ssize_t row = 0 ; row < row_cnt; ++row) {
153+ ssize_t size1 = offset1->get_data ()[row] - prev_offset1;
154+ ssize_t size2 = offset2->get_data ()[row] - prev_offset2;
155+
156+ if (size1 != size2) [[unlikely]] {
157+ return Status::InvalidArgument (
158+ " function {} have different input element sizes of array: {} and {}" ,
159+ get_name (), size1, size2);
160+ }
161+
162+ if (size1 != 3 || size2 != 3 ) {
163+ throw doris::Exception (ErrorCode::INVALID_ARGUMENT,
164+ " function {} requires arrays of size 3" ,
165+ get_name ());
166+ }
167+
168+ ssize_t base1 = prev_offset1;
169+ ssize_t base2 = prev_offset2;
170+
171+ float a1 = float1->get_data ()[base1];
172+ float a2 = float1->get_data ()[base1 + 1 ];
173+ float a3 = float1->get_data ()[base1 + 2 ];
174+
175+ float b1 = float2->get_data ()[base2];
176+ float b2 = float2->get_data ()[base2 + 1 ];
177+ float b3 = float2->get_data ()[base2 + 2 ];
178+
179+ float c1 = a2 * b3 - a3 * b2;
180+ float c2 = a3 * b1 - a1 * b3;
181+ float c3 = a1 * b2 - a2 * b1;
182+
183+ nested_res->insert_value (c1);
184+ nested_res->insert_value (c2);
185+ nested_res->insert_value (c3);
186+
187+ current_offset += 3 ;
188+ offsets_data.push_back (current_offset);
189+
190+ prev_offset1 = offset1->get_data ()[row];
191+ prev_offset2 = offset2->get_data ()[row];
192+ }
193+
194+ auto result_col = ColumnArray::create (
195+ ColumnNullable::create (std::move (nested_res),
196+ ColumnUInt8::create (nested_res->size (), 0 )),
197+ std::move (offsets_res));
198+
199+ block.replace_by_position (result, std::move (result_col));
200+ return Status::OK ();
201+ }
202+ };
203+
204+ } // namespace doris::vectorized
0 commit comments