Skip to content

Commit 6b15099

Browse files
authored
fix(query): Fix vector distance function can not accept array nullable types (#18477)
1 parent 71393c7 commit 6b15099

File tree

5 files changed

+797
-116
lines changed

5 files changed

+797
-116
lines changed

src/query/functions/src/scalars/vector.rs

Lines changed: 168 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use databend_common_expression::types::Buffer;
2323
use databend_common_expression::types::DataType;
2424
use databend_common_expression::types::Float32Type;
2525
use databend_common_expression::types::Float64Type;
26+
use databend_common_expression::types::NullableType;
2627
use databend_common_expression::types::NumberColumn;
2728
use databend_common_expression::types::NumberDataType;
2829
use databend_common_expression::types::NumberScalar;
@@ -34,6 +35,7 @@ use databend_common_expression::types::F64;
3435
use databend_common_expression::vectorize_with_builder_1_arg;
3536
use databend_common_expression::vectorize_with_builder_2_arg;
3637
use databend_common_expression::Column;
38+
use databend_common_expression::EvalContext;
3739
use databend_common_expression::Function;
3840
use databend_common_expression::FunctionDomain;
3941
use databend_common_expression::FunctionEval;
@@ -62,20 +64,22 @@ pub fn register(registry: &mut FunctionRegistry) {
6264
|_, _, _| FunctionDomain::MayThrow,
6365
vectorize_with_builder_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type>(
6466
|lhs, rhs, output, ctx| {
65-
let l =
66-
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(lhs) };
67-
let r =
68-
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(rhs) };
69-
70-
match cosine_distance(l.as_slice(), r.as_slice()) {
71-
Ok(dist) => {
72-
output.push(F32::from(dist));
73-
}
74-
Err(err) => {
75-
ctx.set_error(output.len(), err.to_string());
76-
output.push(F32::from(0.0));
77-
}
67+
calculate_array_distance(lhs, rhs, output, ctx, cosine_distance);
68+
}
69+
),
70+
);
71+
72+
registry.register_passthrough_nullable_2_arg::<ArrayType<NullableType<Float32Type>>, ArrayType<NullableType<Float32Type>>, Float32Type, _, _>(
73+
"cosine_distance",
74+
|_, _, _| FunctionDomain::MayThrow,
75+
vectorize_with_builder_2_arg::<ArrayType<NullableType<Float32Type>>, ArrayType<NullableType<Float32Type>>, Float32Type>(
76+
|lhs, rhs, output, ctx| {
77+
if lhs.validity.null_count() > 0 || rhs.validity.null_count() > 0 {
78+
ctx.set_error(output.len(), "Vector contain null values");
79+
output.push(F32::from(0.0));
80+
return;
7881
}
82+
calculate_array_distance(lhs.column, rhs.column, output, ctx, cosine_distance);
7983
}
8084
),
8185
);
@@ -85,20 +89,22 @@ pub fn register(registry: &mut FunctionRegistry) {
8589
|_, _, _| FunctionDomain::MayThrow,
8690
vectorize_with_builder_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type>(
8791
|lhs, rhs, output, ctx| {
88-
let l =
89-
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(lhs) };
90-
let r =
91-
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(rhs) };
92-
93-
match l1_distance(l.as_slice(), r.as_slice()) {
94-
Ok(dist) => {
95-
output.push(F32::from(dist));
96-
}
97-
Err(err) => {
98-
ctx.set_error(output.len(), err.to_string());
99-
output.push(F32::from(0.0));
100-
}
92+
calculate_array_distance(lhs, rhs, output, ctx, l1_distance);
93+
}
94+
),
95+
);
96+
97+
registry.register_passthrough_nullable_2_arg::<ArrayType<NullableType<Float32Type>>, ArrayType<NullableType<Float32Type>>, Float32Type, _, _>(
98+
"l1_distance",
99+
|_, _, _| FunctionDomain::MayThrow,
100+
vectorize_with_builder_2_arg::<ArrayType<NullableType<Float32Type>>, ArrayType<NullableType<Float32Type>>, Float32Type>(
101+
|lhs, rhs, output, ctx| {
102+
if lhs.validity.null_count() > 0 || rhs.validity.null_count() > 0 {
103+
ctx.set_error(output.len(), "Vector contain null values");
104+
output.push(F32::from(0.0));
105+
return;
101106
}
107+
calculate_array_distance(lhs.column, rhs.column, output, ctx, l1_distance);
102108
}
103109
),
104110
);
@@ -110,20 +116,22 @@ pub fn register(registry: &mut FunctionRegistry) {
110116
|_, _, _| FunctionDomain::MayThrow,
111117
vectorize_with_builder_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type>(
112118
|lhs, rhs, output, ctx| {
113-
let l =
114-
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(lhs) };
115-
let r =
116-
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(rhs) };
117-
118-
match l2_distance(l.as_slice(), r.as_slice()) {
119-
Ok(dist) => {
120-
output.push(F32::from(dist));
121-
}
122-
Err(err) => {
123-
ctx.set_error(output.len(), err.to_string());
124-
output.push(F32::from(0.0));
125-
}
119+
calculate_array_distance(lhs, rhs, output, ctx, l2_distance);
120+
}
121+
),
122+
);
123+
124+
registry.register_passthrough_nullable_2_arg::<ArrayType<NullableType<Float32Type>>, ArrayType<NullableType<Float32Type>>, Float32Type, _, _>(
125+
"l2_distance",
126+
|_, _, _| FunctionDomain::MayThrow,
127+
vectorize_with_builder_2_arg::<ArrayType<NullableType<Float32Type>>, ArrayType<NullableType<Float32Type>>, Float32Type>(
128+
|lhs, rhs, output, ctx| {
129+
if lhs.validity.null_count() > 0 || rhs.validity.null_count() > 0 {
130+
ctx.set_error(output.len(), "Vector contain null values");
131+
output.push(F32::from(0.0));
132+
return;
126133
}
134+
calculate_array_distance(lhs.column, rhs.column, output, ctx, l2_distance);
127135
}
128136
),
129137
);
@@ -133,20 +141,22 @@ pub fn register(registry: &mut FunctionRegistry) {
133141
|_, _, _| FunctionDomain::MayThrow,
134142
vectorize_with_builder_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type>(
135143
|lhs, rhs, output, ctx| {
136-
let l =
137-
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(lhs) };
138-
let r =
139-
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(rhs) };
140-
141-
match inner_product(l.as_slice(), r.as_slice()) {
142-
Ok(dist) => {
143-
output.push(F32::from(dist));
144-
}
145-
Err(err) => {
146-
ctx.set_error(output.len(), err.to_string());
147-
output.push(F32::from(0.0));
148-
}
144+
calculate_array_distance(lhs, rhs, output, ctx, inner_product);
145+
}
146+
),
147+
);
148+
149+
registry.register_passthrough_nullable_2_arg::<ArrayType<NullableType<Float32Type>>, ArrayType<NullableType<Float32Type>>, Float32Type, _, _>(
150+
"inner_product",
151+
|_, _, _| FunctionDomain::MayThrow,
152+
vectorize_with_builder_2_arg::<ArrayType<NullableType<Float32Type>>, ArrayType<NullableType<Float32Type>>, Float32Type>(
153+
|lhs, rhs, output, ctx| {
154+
if lhs.validity.null_count() > 0 || rhs.validity.null_count() > 0 {
155+
ctx.set_error(output.len(), "Vector contain null values");
156+
output.push(F32::from(0.0));
157+
return;
149158
}
159+
calculate_array_distance(lhs.column, rhs.column, output, ctx, inner_product);
150160
}
151161
),
152162
);
@@ -156,20 +166,22 @@ pub fn register(registry: &mut FunctionRegistry) {
156166
|_, _, _| FunctionDomain::MayThrow,
157167
vectorize_with_builder_2_arg::<ArrayType<Float64Type>, ArrayType<Float64Type>, Float64Type>(
158168
|lhs, rhs, output, ctx| {
159-
let l =
160-
unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(lhs) };
161-
let r =
162-
unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(rhs) };
163-
164-
match cosine_distance_64(l.as_slice(), r.as_slice()) {
165-
Ok(dist) => {
166-
output.push(F64::from(dist));
167-
}
168-
Err(err) => {
169-
ctx.set_error(output.len(), err.to_string());
170-
output.push(F64::from(0.0));
171-
}
169+
calculate_array_distance_64(lhs, rhs, output, ctx, cosine_distance_64);
170+
}
171+
),
172+
);
173+
174+
registry.register_passthrough_nullable_2_arg::<ArrayType<NullableType<Float64Type>>, ArrayType<NullableType<Float64Type>>, Float64Type, _, _>(
175+
"cosine_distance",
176+
|_, _, _| FunctionDomain::MayThrow,
177+
vectorize_with_builder_2_arg::<ArrayType<NullableType<Float64Type>>, ArrayType<NullableType<Float64Type>>, Float64Type>(
178+
|lhs, rhs, output, ctx| {
179+
if lhs.validity.null_count() > 0 || rhs.validity.null_count() > 0 {
180+
ctx.set_error(output.len(), "Vector contain null values");
181+
output.push(F64::from(0.0));
182+
return;
172183
}
184+
calculate_array_distance_64(lhs.column, rhs.column, output, ctx, cosine_distance_64);
173185
}
174186
),
175187
);
@@ -179,20 +191,22 @@ pub fn register(registry: &mut FunctionRegistry) {
179191
|_, _, _| FunctionDomain::MayThrow,
180192
vectorize_with_builder_2_arg::<ArrayType<Float64Type>, ArrayType<Float64Type>, Float64Type>(
181193
|lhs, rhs, output, ctx| {
182-
let l =
183-
unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(lhs) };
184-
let r =
185-
unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(rhs) };
186-
187-
match l1_distance_64(l.as_slice(), r.as_slice()) {
188-
Ok(dist) => {
189-
output.push(F64::from(dist));
190-
}
191-
Err(err) => {
192-
ctx.set_error(output.len(), err.to_string());
193-
output.push(F64::from(0.0));
194-
}
194+
calculate_array_distance_64(lhs, rhs, output, ctx, l1_distance_64);
195+
}
196+
),
197+
);
198+
199+
registry.register_passthrough_nullable_2_arg::<ArrayType<NullableType<Float64Type>>, ArrayType<NullableType<Float64Type>>, Float64Type, _, _>(
200+
"l1_distance",
201+
|_, _, _| FunctionDomain::MayThrow,
202+
vectorize_with_builder_2_arg::<ArrayType<NullableType<Float64Type>>, ArrayType<NullableType<Float64Type>>, Float64Type>(
203+
|lhs, rhs, output, ctx| {
204+
if lhs.validity.null_count() > 0 || rhs.validity.null_count() > 0 {
205+
ctx.set_error(output.len(), "Vector contain null values");
206+
output.push(F64::from(0.0));
207+
return;
195208
}
209+
calculate_array_distance_64(lhs.column, rhs.column, output, ctx, l1_distance_64);
196210
}
197211
),
198212
);
@@ -202,20 +216,22 @@ pub fn register(registry: &mut FunctionRegistry) {
202216
|_, _, _| FunctionDomain::MayThrow,
203217
vectorize_with_builder_2_arg::<ArrayType<Float64Type>, ArrayType<Float64Type>, Float64Type>(
204218
|lhs, rhs, output, ctx| {
205-
let l =
206-
unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(lhs) };
207-
let r =
208-
unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(rhs) };
209-
210-
match l2_distance_64(l.as_slice(), r.as_slice()) {
211-
Ok(dist) => {
212-
output.push(F64::from(dist));
213-
}
214-
Err(err) => {
215-
ctx.set_error(output.len(), err.to_string());
216-
output.push(F64::from(0.0));
217-
}
219+
calculate_array_distance_64(lhs, rhs, output, ctx, l2_distance_64);
220+
}
221+
),
222+
);
223+
224+
registry.register_passthrough_nullable_2_arg::<ArrayType<NullableType<Float64Type>>, ArrayType<NullableType<Float64Type>>, Float64Type, _, _>(
225+
"l2_distance",
226+
|_, _, _| FunctionDomain::MayThrow,
227+
vectorize_with_builder_2_arg::<ArrayType<NullableType<Float64Type>>, ArrayType<NullableType<Float64Type>>, Float64Type>(
228+
|lhs, rhs, output, ctx| {
229+
if lhs.validity.null_count() > 0 || rhs.validity.null_count() > 0 {
230+
ctx.set_error(output.len(), "Vector contain null values");
231+
output.push(F64::from(0.0));
232+
return;
218233
}
234+
calculate_array_distance_64(lhs.column, rhs.column, output, ctx, l2_distance_64);
219235
}
220236
),
221237
);
@@ -225,20 +241,22 @@ pub fn register(registry: &mut FunctionRegistry) {
225241
|_, _, _| FunctionDomain::MayThrow,
226242
vectorize_with_builder_2_arg::<ArrayType<Float64Type>, ArrayType<Float64Type>, Float64Type>(
227243
|lhs, rhs, output, ctx| {
228-
let l =
229-
unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(lhs) };
230-
let r =
231-
unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(rhs) };
232-
233-
match inner_product_64(l.as_slice(), r.as_slice()) {
234-
Ok(dist) => {
235-
output.push(F64::from(dist));
236-
}
237-
Err(err) => {
238-
ctx.set_error(output.len(), err.to_string());
239-
output.push(F64::from(0.0));
240-
}
244+
calculate_array_distance_64(lhs, rhs, output, ctx, inner_product_64);
245+
}
246+
),
247+
);
248+
249+
registry.register_passthrough_nullable_2_arg::<ArrayType<NullableType<Float64Type>>, ArrayType<NullableType<Float64Type>>, Float64Type, _, _>(
250+
"inner_product",
251+
|_, _, _| FunctionDomain::MayThrow,
252+
vectorize_with_builder_2_arg::<ArrayType<NullableType<Float64Type>>, ArrayType<NullableType<Float64Type>>, Float64Type>(
253+
|lhs, rhs, output, ctx| {
254+
if lhs.validity.null_count() > 0 || rhs.validity.null_count() > 0 {
255+
ctx.set_error(output.len(), "Vector contain null values");
256+
output.push(F64::from(0.0));
257+
return;
241258
}
259+
calculate_array_distance_64(lhs.column, rhs.column, output, ctx, inner_product_64);
242260
}
243261
),
244262
);
@@ -645,3 +663,49 @@ fn calculate_norm(value: &VectorScalarRef) -> f32 {
645663
}
646664
}
647665
}
666+
667+
fn calculate_array_distance<F>(
668+
lhs: Buffer<F32>,
669+
rhs: Buffer<F32>,
670+
output: &mut Vec<F32>,
671+
ctx: &mut EvalContext,
672+
distance_fn: F,
673+
) where
674+
F: Fn(&[f32], &[f32]) -> Result<f32>,
675+
{
676+
let l = unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(lhs) };
677+
let r = unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(rhs) };
678+
679+
match distance_fn(l.as_slice(), r.as_slice()) {
680+
Ok(dist) => {
681+
output.push(F32::from(dist));
682+
}
683+
Err(err) => {
684+
ctx.set_error(output.len(), err.to_string());
685+
output.push(F32::from(0.0));
686+
}
687+
}
688+
}
689+
690+
fn calculate_array_distance_64<F>(
691+
lhs: Buffer<F64>,
692+
rhs: Buffer<F64>,
693+
output: &mut Vec<F64>,
694+
ctx: &mut EvalContext,
695+
distance_fn: F,
696+
) where
697+
F: Fn(&[f64], &[f64]) -> Result<f64>,
698+
{
699+
let l = unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(lhs) };
700+
let r = unsafe { std::mem::transmute::<Buffer<F64>, Buffer<f64>>(rhs) };
701+
702+
match distance_fn(l.as_slice(), r.as_slice()) {
703+
Ok(dist) => {
704+
output.push(F64::from(dist));
705+
}
706+
Err(err) => {
707+
ctx.set_error(output.len(), err.to_string());
708+
output.push(F64::from(0.0));
709+
}
710+
}
711+
}

0 commit comments

Comments
 (0)