Skip to content

Commit c0b84d3

Browse files
committed
fix some error and add table test
1 parent 8be85db commit c0b84d3

File tree

5 files changed

+230
-236
lines changed

5 files changed

+230
-236
lines changed

be/src/vec/functions/array/function_array_cross_product.cpp

Lines changed: 173 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,184 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
#include "vec/functions/array/function_array_cross_product.h"
18+
#include <gen_cpp/Types_types.h>
1919

20+
#include "common/exception.h"
21+
#include "common/status.h"
22+
#include "runtime/primitive_type.h"
23+
#include "vec/columns/column.h"
24+
#include "vec/columns/column_array.h"
25+
#include "vec/columns/column_nullable.h"
26+
#include "vec/common/assert_cast.h"
27+
#include "vec/core/types.h"
28+
#include "vec/data_types/data_type.h"
29+
#include "vec/data_types/data_type_array.h"
30+
#include "vec/data_types/data_type_nullable.h"
31+
#include "vec/data_types/data_type_number.h"
32+
#include "vec/functions/array/function_array_utils.h"
33+
#include "vec/functions/function.h"
2034
#include "vec/functions/simple_function_factory.h"
35+
#include "vec/utils/util.hpp"
2136

2237
namespace doris::vectorized {
2338

39+
class FunctionArrayCrossProduct : public IFunction {
40+
public:
41+
using DataType = PrimitiveTypeTraits<TYPE_DOUBLE>::DataType;
42+
using ColumnType = PrimitiveTypeTraits<TYPE_DOUBLE>::ColumnType;
43+
44+
static constexpr auto name = "cross_product";
45+
String get_name() const override { return name; }
46+
static FunctionPtr create() { return std::make_shared<FunctionArrayCrossProduct>(); }
47+
size_t get_number_of_arguments() const override { return 2; }
48+
49+
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
50+
if (arguments.size() != 2) {
51+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
52+
"Invalid number of arguments for function {}", get_name());
53+
}
54+
55+
if (arguments[0]->get_primitive_type() != TYPE_ARRAY ||
56+
arguments[1]->get_primitive_type() != TYPE_ARRAY) {
57+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
58+
"Arguments for function {} must be arrays", get_name());
59+
}
60+
61+
// return ARRAY<DOUBLE>
62+
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
63+
}
64+
65+
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
66+
uint32_t result, size_t input_rows_count) const override {
67+
const auto& arg1 = block.get_by_position(arguments[0]);
68+
const auto& arg2 = block.get_by_position(arguments[1]);
69+
70+
const IColumn* col1 = arg1.column.get();
71+
const IColumn* col2 = arg2.column.get();
72+
73+
const ColumnConst* col1_const = nullptr;
74+
const ColumnConst* col2_const = nullptr;
75+
76+
if (is_column_const(*col1)) {
77+
col1_const = assert_cast<const ColumnConst*>(col1);
78+
col1 = &col1_const->get_data_column();
79+
}
80+
if (is_column_const(*col2)) {
81+
col2_const = assert_cast<const ColumnConst*>(col2);
82+
col2 = &col2_const->get_data_column();
83+
}
84+
85+
const ColumnArray* arr1 = nullptr;
86+
const ColumnArray* arr2 = nullptr;
87+
88+
if (col1->is_nullable()) {
89+
auto nullable1 = assert_cast<const ColumnNullable*>(col1);
90+
arr1 = assert_cast<const ColumnArray*>(nullable1->get_nested_column_ptr().get());
91+
} else {
92+
arr1 = assert_cast<const ColumnArray*>(col1);
93+
}
94+
if (col2->is_nullable()) {
95+
auto nullable2 = assert_cast<const ColumnNullable*>(col2);
96+
arr2 = assert_cast<const ColumnArray*>(nullable2->get_nested_column_ptr().get());
97+
} else {
98+
arr2 = assert_cast<const ColumnArray*>(col2);
99+
}
100+
101+
const ColumnFloat64* float1 = nullptr;
102+
const ColumnFloat64* float2 = nullptr;
103+
104+
if (arr1->get_data_ptr()->is_nullable()) {
105+
if (arr1->get_data_ptr()->has_null()) {
106+
return Status::InvalidArgument("First argument for function {} cannot have null elements",
107+
get_name());
108+
}
109+
auto nullable1 = assert_cast<const ColumnNullable*>(arr1->get_data_ptr().get());
110+
float1 = assert_cast<const ColumnFloat64*>(nullable1->get_nested_column_ptr().get());
111+
} else {
112+
float1 = assert_cast<const ColumnFloat64*>(arr1->get_data_ptr().get());
113+
}
114+
115+
if (arr2->get_data_ptr()->is_nullable()) {
116+
if (arr2->get_data_ptr()->has_null()) {
117+
return Status::InvalidArgument(
118+
"Second argument for function {} cannot have null elements",
119+
get_name());
120+
}
121+
auto nullable2 = assert_cast<const ColumnNullable*>(arr2->get_data_ptr().get());
122+
float2 = assert_cast<const ColumnFloat64*>(nullable2->get_nested_column_ptr().get());
123+
} else {
124+
float2 = assert_cast<const ColumnFloat64*>(arr2->get_data_ptr().get());
125+
}
126+
127+
const auto* offset1 =
128+
assert_cast<const ColumnArray::ColumnOffsets*>(arr1->get_offsets_ptr().get());
129+
const auto* offset2 =
130+
assert_cast<const ColumnArray::ColumnOffsets*>(arr2->get_offsets_ptr().get());
131+
132+
// prepare result data
133+
auto nested_res = ColumnFloat64::create();
134+
auto& nested_data = nested_res->get_data();
135+
nested_data.resize(3 * input_rows_count);
136+
137+
auto offsets_res = ColumnArray::ColumnOffsets::create();
138+
auto& offsets_data = offsets_res->get_data();
139+
offsets_data.resize(input_rows_count);
140+
141+
size_t current_offset = 0;
142+
for (ssize_t row = 0; row < input_rows_count; ++row) {
143+
ssize_t row1 = col1_const ? 0 : row;
144+
ssize_t row2 = col2_const ? 0 : row;
145+
146+
ssize_t prev_offset1 = (row1 == 0) ? 0 : offset1->get_data()[row1 - 1];
147+
ssize_t prev_offset2 = (row2 == 0) ? 0 : offset2->get_data()[row2 - 1];
148+
149+
ssize_t size1 = offset1->get_data()[row] - prev_offset1;
150+
ssize_t size2 = offset2->get_data()[row] - prev_offset2;
151+
152+
if (size1 == 0 || size2 == 0) {
153+
nested_data[current_offset] = 0;
154+
nested_data[current_offset + 1] = 0;
155+
nested_data[current_offset + 2] = 0;
156+
157+
current_offset += 3;
158+
offsets_data[row] = current_offset;
159+
continue;
160+
}
161+
162+
if (size1 != 3 || size2 != 3) {
163+
return Status::InvalidArgument(
164+
"function {} requires arrays of size 3", get_name());
165+
}
166+
167+
ssize_t base1 = prev_offset1;
168+
ssize_t base2 = prev_offset2;
169+
170+
double a1 = float1->get_data()[base1];
171+
double a2 = float1->get_data()[base1 + 1];
172+
double a3 = float1->get_data()[base1 + 2];
173+
174+
double b1 = float2->get_data()[base2];
175+
double b2 = float2->get_data()[base2 + 1];
176+
double b3 = float2->get_data()[base2 + 2];
177+
178+
nested_data[current_offset] = a2 * b3 - a3 * b2;
179+
nested_data[current_offset + 1] = a3 * b1 - a1 * b3;
180+
nested_data[current_offset + 2] = a1 * b2 - a2 * b1;
181+
182+
current_offset += 3;
183+
offsets_data[row] = current_offset;
184+
}
185+
186+
auto result_col = ColumnArray::create(
187+
ColumnNullable::create(std::move(nested_res),
188+
ColumnUInt8::create(nested_res->size(), 0)),
189+
std::move(offsets_res));
190+
191+
block.replace_by_position(result, std::move(result_col));
192+
return Status::OK();
193+
}
194+
};
195+
24196
void register_function_array_cross_product(SimpleFunctionFactory& factory) {
25197
factory.register_function<FunctionArrayCrossProduct>();
26198
}

be/src/vec/functions/array/function_array_cross_product.h

Lines changed: 0 additions & 204 deletions
This file was deleted.

0 commit comments

Comments
 (0)