Skip to content

Commit 8be85db

Browse files
committed
[Feature](function) Support function cross_product of DuckDB
1 parent 28be8c7 commit 8be85db

File tree

8 files changed

+393
-0
lines changed

8 files changed

+393
-0
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "vec/functions/array/function_array_cross_product.h"
19+
20+
#include "vec/functions/simple_function_factory.h"
21+
22+
namespace doris::vectorized {
23+
24+
// Registers FunctionArrayCrossProduct with the given function factory.
void register_function_array_cross_product(SimpleFunctionFactory& factory) {
    factory.register_function<FunctionArrayCrossProduct>();
}
27+
28+
} // namespace doris::vectorized
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#pragma once
19+
20+
#include <gen_cpp/Types_types.h>
21+
22+
#include "common/exception.h"
23+
#include "common/status.h"
24+
#include "runtime/primitive_type.h"
25+
#include "vec/columns/column.h"
26+
#include "vec/columns/column_array.h"
27+
#include "vec/columns/column_nullable.h"
28+
#include "vec/common/assert_cast.h"
29+
#include "vec/core/types.h"
30+
#include "vec/data_types/data_type.h"
31+
#include "vec/data_types/data_type_array.h"
32+
#include "vec/data_types/data_type_nullable.h"
33+
#include "vec/data_types/data_type_number.h"
34+
#include "vec/functions/array/function_array_utils.h"
35+
#include "vec/functions/function.h"
36+
#include "vec/utils/util.hpp"
37+
38+
namespace doris::vectorized {
39+
40+
class FunctionArrayCrossProduct : public IFunction {
41+
public:
42+
using DataType = PrimitiveTypeTraits<TYPE_FLOAT>::DataType;
43+
using ColumnType = PrimitiveTypeTraits<TYPE_FLOAT>::ColumnType;
44+
45+
static constexpr auto name = "cross_product";
46+
String get_name() const override { return name; }
47+
static FunctionPtr create() { return std::make_shared<FunctionArrayCrossProduct>(); }
48+
size_t get_number_of_arguments() const override { return 2; }
49+
50+
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
51+
if (arguments.size() != 2) {
52+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
53+
"Invalid number of arguments for function {}", get_name());
54+
}
55+
56+
if (arguments[0]->get_primitive_type() != TYPE_ARRAY ||
57+
arguments[1]->get_primitive_type() != TYPE_ARRAY) {
58+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
59+
"Arguments for function {} must be arrays", get_name());
60+
}
61+
62+
// return ARRAY<FLOAT>
63+
return std::make_shared<DataTypeArray>(
64+
std::make_shared<DataTypeFloat32>());
65+
}
66+
67+
// strict semantics: do not allow NULL
68+
bool use_default_implementation_for_nulls() const override { return false; }
69+
70+
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
71+
uint32_t result, size_t input_rows_count) const override {
72+
const auto& arg1 = block.get_by_position(arguments[0]);
73+
const auto& arg2 = block.get_by_position(arguments[1]);
74+
75+
auto col1 = arg1.column->convert_to_full_column_if_const();
76+
auto col2 = arg2.column->convert_to_full_column_if_const();
77+
78+
if (col1->size() != col2->size()) {
79+
return Status::RuntimeError(
80+
fmt::format("function {} have different input array sizes: {} and {}",
81+
get_name(), col1->size(), col2->size()));
82+
}
83+
84+
const ColumnArray* arr1 = nullptr;
85+
const ColumnArray* arr2 = nullptr;
86+
87+
if (const auto* nullable =
88+
typeid_cast<const ColumnNullable*>(col1.get())) {
89+
if (nullable->has_null()) {
90+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
91+
"First argument for function {} cannot be null", get_name());
92+
}
93+
arr1 = assert_cast<const ColumnArray*>(nullable->get_nested_column_ptr().get());
94+
} else {
95+
arr1 = assert_cast<const ColumnArray*>(col1.get());
96+
}
97+
98+
if (const auto* nullable =
99+
typeid_cast<const ColumnNullable*>(col2.get())) {
100+
if (nullable->has_null()) {
101+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
102+
"Second argument for function {} cannot be null",
103+
get_name());
104+
}
105+
arr2 = assert_cast<const ColumnArray*>(nullable->get_nested_column_ptr().get());
106+
} else {
107+
arr2 = assert_cast<const ColumnArray*>(col2.get());
108+
}
109+
110+
const ColumnFloat32* float1 = nullptr;
111+
const ColumnFloat32* float2 = nullptr;
112+
113+
if (const auto* nullable =
114+
typeid_cast<const ColumnNullable*>(arr1->get_data_ptr().get())) {
115+
if (nullable->has_null()) {
116+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
117+
"First argument for function {} cannot have null elements",
118+
get_name());
119+
}
120+
float1 = assert_cast<const ColumnFloat32*>(nullable->get_nested_column_ptr().get());
121+
} else {
122+
float1 = assert_cast<const ColumnFloat32*>(arr1->get_data_ptr().get());
123+
}
124+
125+
if (const auto* nullable =
126+
typeid_cast<const ColumnNullable*>(arr2->get_data_ptr().get())) {
127+
if (nullable->has_null()) {
128+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
129+
"Second argument for function {} cannot have null elements",
130+
get_name());
131+
}
132+
float2 = assert_cast<const ColumnFloat32*>(nullable->get_nested_column_ptr().get());
133+
} else {
134+
float2 = assert_cast<const ColumnFloat32*>(arr2->get_data_ptr().get());
135+
}
136+
137+
const auto* offset1 =
138+
assert_cast<const ColumnArray::ColumnOffsets*>(arr1->get_offsets_ptr().get());
139+
const auto* offset2 =
140+
assert_cast<const ColumnArray::ColumnOffsets*>(arr2->get_offsets_ptr().get());
141+
142+
// prepare result data
143+
auto nested_res = ColumnFloat32::create();
144+
auto offsets_res = ColumnArray::ColumnOffsets::create();
145+
auto& offsets_data = offsets_res->get_data();
146+
offsets_data.reserve(input_rows_count);
147+
size_t current_offset = 0;
148+
149+
size_t row_cnt = offset1->size();
150+
size_t prev_offset1 = 0;
151+
size_t prev_offset2 = 0;
152+
for (ssize_t row = 0; row < row_cnt; ++row) {
153+
ssize_t size1 = offset1->get_data()[row] - prev_offset1;
154+
ssize_t size2 = offset2->get_data()[row] - prev_offset2;
155+
156+
if (size1 != size2) [[unlikely]] {
157+
return Status::InvalidArgument(
158+
"function {} have different input element sizes of array: {} and {}",
159+
get_name(), size1, size2);
160+
}
161+
162+
if (size1 != 3 || size2 != 3) {
163+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
164+
"function {} requires arrays of size 3",
165+
get_name());
166+
}
167+
168+
ssize_t base1 = prev_offset1;
169+
ssize_t base2 = prev_offset2;
170+
171+
float a1 = float1->get_data()[base1];
172+
float a2 = float1->get_data()[base1 + 1];
173+
float a3 = float1->get_data()[base1 + 2];
174+
175+
float b1 = float2->get_data()[base2];
176+
float b2 = float2->get_data()[base2 + 1];
177+
float b3 = float2->get_data()[base2 + 2];
178+
179+
float c1 = a2 * b3 - a3 * b2;
180+
float c2 = a3 * b1 - a1 * b3;
181+
float c3 = a1 * b2 - a2 * b1;
182+
183+
nested_res->insert_value(c1);
184+
nested_res->insert_value(c2);
185+
nested_res->insert_value(c3);
186+
187+
current_offset += 3;
188+
offsets_data.push_back(current_offset);
189+
190+
prev_offset1 = offset1->get_data()[row];
191+
prev_offset2 = offset2->get_data()[row];
192+
}
193+
194+
auto result_col = ColumnArray::create(
195+
ColumnNullable::create(std::move(nested_res),
196+
ColumnUInt8::create(nested_res->size(), 0)),
197+
std::move(offsets_res));
198+
199+
block.replace_by_position(result, std::move(result_col));
200+
return Status::OK();
201+
}
202+
};
203+
204+
} // namespace doris::vectorized

be/src/vec/functions/array/function_array_register.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ void register_function_array_pushback(SimpleFunctionFactory& factory);
5353
void register_function_array_first_or_last_index(SimpleFunctionFactory& factory);
5454
void register_function_array_cum_sum(SimpleFunctionFactory& factory);
5555
void register_function_array_count(SimpleFunctionFactory&);
56+
void register_function_array_cross_product(SimpleFunctionFactory& factory);
5657
void register_function_array_filter_function(SimpleFunctionFactory&);
5758
void register_function_array_splits(SimpleFunctionFactory&);
5859
void register_function_array_contains_all(SimpleFunctionFactory&);
@@ -91,6 +92,7 @@ void register_function_array(SimpleFunctionFactory& factory) {
9192
register_function_array_first_or_last_index(factory);
9293
register_function_array_cum_sum(factory);
9394
register_function_array_count(factory);
95+
register_function_array_cross_product(factory);
9496
register_function_array_filter_function(factory);
9597
register_function_array_splits(factory);
9698
register_function_array_contains_all(factory);

fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@
153153
import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateMap;
154154
import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateNamedStruct;
155155
import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateStruct;
156+
import org.apache.doris.nereids.trees.expressions.functions.scalar.CrossProduct;
156157
import org.apache.doris.nereids.trees.expressions.functions.scalar.Csc;
157158
import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog;
158159
import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentDate;
@@ -695,6 +696,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
695696
scalar(CreateMap.class, "map"),
696697
scalar(CreateStruct.class, "struct"),
697698
scalar(CreateNamedStruct.class, "named_struct"),
699+
scalar(CrossProduct.class, "cross_product"),
698700
scalar(CurrentCatalog.class, "current_catalog"),
699701
scalar(CurrentDate.class, "curdate", "current_date"),
700702
scalar(CurrentTime.class, "curtime", "current_time"),
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.trees.expressions.functions.scalar;
19+
20+
import org.apache.doris.catalog.FunctionSignature;
21+
import org.apache.doris.nereids.trees.expressions.Expression;
22+
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
23+
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
24+
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
25+
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
26+
import org.apache.doris.nereids.types.ArrayType;
27+
import org.apache.doris.nereids.types.FloatType;
28+
29+
import com.google.common.base.Preconditions;
30+
import com.google.common.collect.ImmutableList;
31+
32+
import java.util.List;
33+
34+
/**
35+
* cosine_distance function
36+
*/
37+
public class CrossProduct extends ScalarFunction implements ExplicitlyCastableSignature,
38+
BinaryExpression, AlwaysNotNullable {
39+
40+
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
41+
FunctionSignature.ret(ArrayType.of(FloatType.INSTANCE))
42+
.args(ArrayType.of(FloatType.INSTANCE), ArrayType.of(FloatType.INSTANCE))
43+
);
44+
45+
/**
46+
* constructor with 1 argument.
47+
*/
48+
public CrossProduct(Expression arg0, Expression arg1) {
49+
super("cross_product", arg0, arg1);
50+
}
51+
52+
/** constructor for withChildren and reuse signature */
53+
private CrossProduct(ScalarFunctionParams functionParams) {
54+
super(functionParams);
55+
}
56+
57+
/**
58+
* withChildren.
59+
*/
60+
@Override
61+
public CrossProduct withChildren(List<Expression> children) {
62+
Preconditions.checkArgument(children.size() == 2);
63+
return new CrossProduct(getFunctionParams(children));
64+
}
65+
66+
@Override
67+
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
68+
return visitor.visitCrossProduct(this, context);
69+
}
70+
71+
@Override
72+
public List<FunctionSignature> getSignatures() {
73+
return SIGNATURES;
74+
}
75+
}

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@
164164
import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateMap;
165165
import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateNamedStruct;
166166
import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateStruct;
167+
import org.apache.doris.nereids.trees.expressions.functions.scalar.CrossProduct;
167168
import org.apache.doris.nereids.trees.expressions.functions.scalar.Csc;
168169
import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog;
169170
import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentDate;
@@ -1081,6 +1082,10 @@ default R visitCountSubstring(CountSubstring countSubstring, C context) {
10811082
return visitScalarFunction(countSubstring, context);
10821083
}
10831084

1085+
default R visitCrossProduct(CrossProduct crossProduct, C context) {
1086+
return visitScalarFunction(crossProduct, context);
1087+
}
1088+
10841089
default R visitCurrentCatalog(CurrentCatalog currentCatalog, C context) {
10851090
return visitScalarFunction(currentCatalog, context);
10861091
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
-- This file is automatically generated. You should know what you did if you want to edit this
2+
-- !sql --
3+
[-1, 2, -1]
4+
5+
-- !sql --
6+
[0, 0, 0]
7+
8+
-- !sql --
9+
[0, 0, 0]
10+
11+
-- !sql --
12+
[0, 0, 1]
13+
14+
-- !sql --
15+
[0, 0, -1]

0 commit comments

Comments
 (0)