@@ -23,6 +23,7 @@ use databend_common_expression::types::Buffer;
23
23
use databend_common_expression:: types:: DataType ;
24
24
use databend_common_expression:: types:: Float32Type ;
25
25
use databend_common_expression:: types:: Float64Type ;
26
+ use databend_common_expression:: types:: NullableType ;
26
27
use databend_common_expression:: types:: NumberColumn ;
27
28
use databend_common_expression:: types:: NumberDataType ;
28
29
use databend_common_expression:: types:: NumberScalar ;
@@ -34,6 +35,7 @@ use databend_common_expression::types::F64;
34
35
use databend_common_expression:: vectorize_with_builder_1_arg;
35
36
use databend_common_expression:: vectorize_with_builder_2_arg;
36
37
use databend_common_expression:: Column ;
38
+ use databend_common_expression:: EvalContext ;
37
39
use databend_common_expression:: Function ;
38
40
use databend_common_expression:: FunctionDomain ;
39
41
use databend_common_expression:: FunctionEval ;
@@ -62,20 +64,22 @@ pub fn register(registry: &mut FunctionRegistry) {
62
64
|_, _, _| FunctionDomain :: MayThrow ,
63
65
vectorize_with_builder_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type > (
64
66
|lhs, rhs, output, ctx| {
65
- let l =
66
- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( lhs) } ;
67
- let r =
68
- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( rhs) } ;
69
-
70
- match cosine_distance ( l. as_slice ( ) , r. as_slice ( ) ) {
71
- Ok ( dist) => {
72
- output. push ( F32 :: from ( dist) ) ;
73
- }
74
- Err ( err) => {
75
- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
76
- output. push ( F32 :: from ( 0.0 ) ) ;
77
- }
67
+ calculate_array_distance ( lhs, rhs, output, ctx, cosine_distance) ;
68
+ }
69
+ ) ,
70
+ ) ;
71
+
72
+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type , _ , _ > (
73
+ "cosine_distance" ,
74
+ |_, _, _| FunctionDomain :: MayThrow ,
75
+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type > (
76
+ |lhs, rhs, output, ctx| {
77
+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
78
+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
79
+ output. push ( F32 :: from ( 0.0 ) ) ;
80
+ return ;
78
81
}
82
+ calculate_array_distance ( lhs. column , rhs. column , output, ctx, cosine_distance) ;
79
83
}
80
84
) ,
81
85
) ;
@@ -85,20 +89,22 @@ pub fn register(registry: &mut FunctionRegistry) {
85
89
|_, _, _| FunctionDomain :: MayThrow ,
86
90
vectorize_with_builder_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type > (
87
91
|lhs, rhs, output, ctx| {
88
- let l =
89
- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( lhs) } ;
90
- let r =
91
- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( rhs) } ;
92
-
93
- match l1_distance ( l. as_slice ( ) , r. as_slice ( ) ) {
94
- Ok ( dist) => {
95
- output. push ( F32 :: from ( dist) ) ;
96
- }
97
- Err ( err) => {
98
- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
99
- output. push ( F32 :: from ( 0.0 ) ) ;
100
- }
92
+ calculate_array_distance ( lhs, rhs, output, ctx, l1_distance) ;
93
+ }
94
+ ) ,
95
+ ) ;
96
+
97
+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type , _ , _ > (
98
+ "l1_distance" ,
99
+ |_, _, _| FunctionDomain :: MayThrow ,
100
+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type > (
101
+ |lhs, rhs, output, ctx| {
102
+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
103
+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
104
+ output. push ( F32 :: from ( 0.0 ) ) ;
105
+ return ;
101
106
}
107
+ calculate_array_distance ( lhs. column , rhs. column , output, ctx, l1_distance) ;
102
108
}
103
109
) ,
104
110
) ;
@@ -110,20 +116,22 @@ pub fn register(registry: &mut FunctionRegistry) {
110
116
|_, _, _| FunctionDomain :: MayThrow ,
111
117
vectorize_with_builder_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type > (
112
118
|lhs, rhs, output, ctx| {
113
- let l =
114
- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( lhs) } ;
115
- let r =
116
- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( rhs) } ;
117
-
118
- match l2_distance ( l. as_slice ( ) , r. as_slice ( ) ) {
119
- Ok ( dist) => {
120
- output. push ( F32 :: from ( dist) ) ;
121
- }
122
- Err ( err) => {
123
- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
124
- output. push ( F32 :: from ( 0.0 ) ) ;
125
- }
119
+ calculate_array_distance ( lhs, rhs, output, ctx, l2_distance) ;
120
+ }
121
+ ) ,
122
+ ) ;
123
+
124
+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type , _ , _ > (
125
+ "l2_distance" ,
126
+ |_, _, _| FunctionDomain :: MayThrow ,
127
+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type > (
128
+ |lhs, rhs, output, ctx| {
129
+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
130
+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
131
+ output. push ( F32 :: from ( 0.0 ) ) ;
132
+ return ;
126
133
}
134
+ calculate_array_distance ( lhs. column , rhs. column , output, ctx, l2_distance) ;
127
135
}
128
136
) ,
129
137
) ;
@@ -133,20 +141,22 @@ pub fn register(registry: &mut FunctionRegistry) {
133
141
|_, _, _| FunctionDomain :: MayThrow ,
134
142
vectorize_with_builder_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type > (
135
143
|lhs, rhs, output, ctx| {
136
- let l =
137
- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( lhs) } ;
138
- let r =
139
- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( rhs) } ;
140
-
141
- match inner_product ( l. as_slice ( ) , r. as_slice ( ) ) {
142
- Ok ( dist) => {
143
- output. push ( F32 :: from ( dist) ) ;
144
- }
145
- Err ( err) => {
146
- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
147
- output. push ( F32 :: from ( 0.0 ) ) ;
148
- }
144
+ calculate_array_distance ( lhs, rhs, output, ctx, inner_product) ;
145
+ }
146
+ ) ,
147
+ ) ;
148
+
149
+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type , _ , _ > (
150
+ "inner_product" ,
151
+ |_, _, _| FunctionDomain :: MayThrow ,
152
+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type > (
153
+ |lhs, rhs, output, ctx| {
154
+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
155
+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
156
+ output. push ( F32 :: from ( 0.0 ) ) ;
157
+ return ;
149
158
}
159
+ calculate_array_distance ( lhs. column , rhs. column , output, ctx, inner_product) ;
150
160
}
151
161
) ,
152
162
) ;
@@ -156,20 +166,22 @@ pub fn register(registry: &mut FunctionRegistry) {
156
166
|_, _, _| FunctionDomain :: MayThrow ,
157
167
vectorize_with_builder_2_arg :: < ArrayType < Float64Type > , ArrayType < Float64Type > , Float64Type > (
158
168
|lhs, rhs, output, ctx| {
159
- let l =
160
- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( lhs) } ;
161
- let r =
162
- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( rhs) } ;
163
-
164
- match cosine_distance_64 ( l. as_slice ( ) , r. as_slice ( ) ) {
165
- Ok ( dist) => {
166
- output. push ( F64 :: from ( dist) ) ;
167
- }
168
- Err ( err) => {
169
- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
170
- output. push ( F64 :: from ( 0.0 ) ) ;
171
- }
169
+ calculate_array_distance_64 ( lhs, rhs, output, ctx, cosine_distance_64) ;
170
+ }
171
+ ) ,
172
+ ) ;
173
+
174
+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type , _ , _ > (
175
+ "cosine_distance" ,
176
+ |_, _, _| FunctionDomain :: MayThrow ,
177
+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type > (
178
+ |lhs, rhs, output, ctx| {
179
+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
180
+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
181
+ output. push ( F64 :: from ( 0.0 ) ) ;
182
+ return ;
172
183
}
184
+ calculate_array_distance_64 ( lhs. column , rhs. column , output, ctx, cosine_distance_64) ;
173
185
}
174
186
) ,
175
187
) ;
@@ -179,20 +191,22 @@ pub fn register(registry: &mut FunctionRegistry) {
179
191
|_, _, _| FunctionDomain :: MayThrow ,
180
192
vectorize_with_builder_2_arg :: < ArrayType < Float64Type > , ArrayType < Float64Type > , Float64Type > (
181
193
|lhs, rhs, output, ctx| {
182
- let l =
183
- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( lhs) } ;
184
- let r =
185
- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( rhs) } ;
186
-
187
- match l1_distance_64 ( l. as_slice ( ) , r. as_slice ( ) ) {
188
- Ok ( dist) => {
189
- output. push ( F64 :: from ( dist) ) ;
190
- }
191
- Err ( err) => {
192
- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
193
- output. push ( F64 :: from ( 0.0 ) ) ;
194
- }
194
+ calculate_array_distance_64 ( lhs, rhs, output, ctx, l1_distance_64) ;
195
+ }
196
+ ) ,
197
+ ) ;
198
+
199
+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type , _ , _ > (
200
+ "l1_distance" ,
201
+ |_, _, _| FunctionDomain :: MayThrow ,
202
+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type > (
203
+ |lhs, rhs, output, ctx| {
204
+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
205
+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
206
+ output. push ( F64 :: from ( 0.0 ) ) ;
207
+ return ;
195
208
}
209
+ calculate_array_distance_64 ( lhs. column , rhs. column , output, ctx, l1_distance_64) ;
196
210
}
197
211
) ,
198
212
) ;
@@ -202,20 +216,22 @@ pub fn register(registry: &mut FunctionRegistry) {
202
216
|_, _, _| FunctionDomain :: MayThrow ,
203
217
vectorize_with_builder_2_arg :: < ArrayType < Float64Type > , ArrayType < Float64Type > , Float64Type > (
204
218
|lhs, rhs, output, ctx| {
205
- let l =
206
- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( lhs) } ;
207
- let r =
208
- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( rhs) } ;
209
-
210
- match l2_distance_64 ( l. as_slice ( ) , r. as_slice ( ) ) {
211
- Ok ( dist) => {
212
- output. push ( F64 :: from ( dist) ) ;
213
- }
214
- Err ( err) => {
215
- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
216
- output. push ( F64 :: from ( 0.0 ) ) ;
217
- }
219
+ calculate_array_distance_64 ( lhs, rhs, output, ctx, l2_distance_64) ;
220
+ }
221
+ ) ,
222
+ ) ;
223
+
224
+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type , _ , _ > (
225
+ "l2_distance" ,
226
+ |_, _, _| FunctionDomain :: MayThrow ,
227
+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type > (
228
+ |lhs, rhs, output, ctx| {
229
+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
230
+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
231
+ output. push ( F64 :: from ( 0.0 ) ) ;
232
+ return ;
218
233
}
234
+ calculate_array_distance_64 ( lhs. column , rhs. column , output, ctx, l2_distance_64) ;
219
235
}
220
236
) ,
221
237
) ;
@@ -225,20 +241,22 @@ pub fn register(registry: &mut FunctionRegistry) {
225
241
|_, _, _| FunctionDomain :: MayThrow ,
226
242
vectorize_with_builder_2_arg :: < ArrayType < Float64Type > , ArrayType < Float64Type > , Float64Type > (
227
243
|lhs, rhs, output, ctx| {
228
- let l =
229
- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( lhs) } ;
230
- let r =
231
- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( rhs) } ;
232
-
233
- match inner_product_64 ( l. as_slice ( ) , r. as_slice ( ) ) {
234
- Ok ( dist) => {
235
- output. push ( F64 :: from ( dist) ) ;
236
- }
237
- Err ( err) => {
238
- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
239
- output. push ( F64 :: from ( 0.0 ) ) ;
240
- }
244
+ calculate_array_distance_64 ( lhs, rhs, output, ctx, inner_product_64) ;
245
+ }
246
+ ) ,
247
+ ) ;
248
+
249
+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type , _ , _ > (
250
+ "inner_product" ,
251
+ |_, _, _| FunctionDomain :: MayThrow ,
252
+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type > (
253
+ |lhs, rhs, output, ctx| {
254
+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
255
+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
256
+ output. push ( F64 :: from ( 0.0 ) ) ;
257
+ return ;
241
258
}
259
+ calculate_array_distance_64 ( lhs. column , rhs. column , output, ctx, inner_product_64) ;
242
260
}
243
261
) ,
244
262
) ;
@@ -645,3 +663,49 @@ fn calculate_norm(value: &VectorScalarRef) -> f32 {
645
663
}
646
664
}
647
665
}
666
+
667
+ fn calculate_array_distance < F > (
668
+ lhs : Buffer < F32 > ,
669
+ rhs : Buffer < F32 > ,
670
+ output : & mut Vec < F32 > ,
671
+ ctx : & mut EvalContext ,
672
+ distance_fn : F ,
673
+ ) where
674
+ F : Fn ( & [ f32 ] , & [ f32 ] ) -> Result < f32 > ,
675
+ {
676
+ let l = unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( lhs) } ;
677
+ let r = unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( rhs) } ;
678
+
679
+ match distance_fn ( l. as_slice ( ) , r. as_slice ( ) ) {
680
+ Ok ( dist) => {
681
+ output. push ( F32 :: from ( dist) ) ;
682
+ }
683
+ Err ( err) => {
684
+ ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
685
+ output. push ( F32 :: from ( 0.0 ) ) ;
686
+ }
687
+ }
688
+ }
689
+
690
+ fn calculate_array_distance_64 < F > (
691
+ lhs : Buffer < F64 > ,
692
+ rhs : Buffer < F64 > ,
693
+ output : & mut Vec < F64 > ,
694
+ ctx : & mut EvalContext ,
695
+ distance_fn : F ,
696
+ ) where
697
+ F : Fn ( & [ f64 ] , & [ f64 ] ) -> Result < f64 > ,
698
+ {
699
+ let l = unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( lhs) } ;
700
+ let r = unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( rhs) } ;
701
+
702
+ match distance_fn ( l. as_slice ( ) , r. as_slice ( ) ) {
703
+ Ok ( dist) => {
704
+ output. push ( F64 :: from ( dist) ) ;
705
+ }
706
+ Err ( err) => {
707
+ ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
708
+ output. push ( F64 :: from ( 0.0 ) ) ;
709
+ }
710
+ }
711
+ }
0 commit comments