@@ -23,6 +23,7 @@ use databend_common_expression::types::Buffer;
2323use databend_common_expression:: types:: DataType ;
2424use databend_common_expression:: types:: Float32Type ;
2525use databend_common_expression:: types:: Float64Type ;
26+ use databend_common_expression:: types:: NullableType ;
2627use databend_common_expression:: types:: NumberColumn ;
2728use databend_common_expression:: types:: NumberDataType ;
2829use databend_common_expression:: types:: NumberScalar ;
@@ -34,6 +35,7 @@ use databend_common_expression::types::F64;
3435use databend_common_expression:: vectorize_with_builder_1_arg;
3536use databend_common_expression:: vectorize_with_builder_2_arg;
3637use databend_common_expression:: Column ;
38+ use databend_common_expression:: EvalContext ;
3739use databend_common_expression:: Function ;
3840use databend_common_expression:: FunctionDomain ;
3941use databend_common_expression:: FunctionEval ;
@@ -62,20 +64,22 @@ pub fn register(registry: &mut FunctionRegistry) {
6264 |_, _, _| FunctionDomain :: MayThrow ,
6365 vectorize_with_builder_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type > (
6466 |lhs, rhs, output, ctx| {
65- let l =
66- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( lhs) } ;
67- let r =
68- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( rhs) } ;
69-
70- match cosine_distance ( l. as_slice ( ) , r. as_slice ( ) ) {
71- Ok ( dist) => {
72- output. push ( F32 :: from ( dist) ) ;
73- }
74- Err ( err) => {
75- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
76- output. push ( F32 :: from ( 0.0 ) ) ;
77- }
67+ calculate_array_distance ( lhs, rhs, output, ctx, cosine_distance) ;
68+ }
69+ ) ,
70+ ) ;
71+
72+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type , _ , _ > (
73+ "cosine_distance" ,
74+ |_, _, _| FunctionDomain :: MayThrow ,
75+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type > (
76+ |lhs, rhs, output, ctx| {
77+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
78+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
79+ output. push ( F32 :: from ( 0.0 ) ) ;
80+ return ;
7881 }
82+ calculate_array_distance ( lhs. column , rhs. column , output, ctx, cosine_distance) ;
7983 }
8084 ) ,
8185 ) ;
@@ -85,20 +89,22 @@ pub fn register(registry: &mut FunctionRegistry) {
8589 |_, _, _| FunctionDomain :: MayThrow ,
8690 vectorize_with_builder_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type > (
8791 |lhs, rhs, output, ctx| {
88- let l =
89- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( lhs) } ;
90- let r =
91- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( rhs) } ;
92-
93- match l1_distance ( l. as_slice ( ) , r. as_slice ( ) ) {
94- Ok ( dist) => {
95- output. push ( F32 :: from ( dist) ) ;
96- }
97- Err ( err) => {
98- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
99- output. push ( F32 :: from ( 0.0 ) ) ;
100- }
92+ calculate_array_distance ( lhs, rhs, output, ctx, l1_distance) ;
93+ }
94+ ) ,
95+ ) ;
96+
97+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type , _ , _ > (
98+ "l1_distance" ,
99+ |_, _, _| FunctionDomain :: MayThrow ,
100+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type > (
101+ |lhs, rhs, output, ctx| {
102+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
103+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
104+ output. push ( F32 :: from ( 0.0 ) ) ;
105+ return ;
101106 }
107+ calculate_array_distance ( lhs. column , rhs. column , output, ctx, l1_distance) ;
102108 }
103109 ) ,
104110 ) ;
@@ -110,20 +116,22 @@ pub fn register(registry: &mut FunctionRegistry) {
110116 |_, _, _| FunctionDomain :: MayThrow ,
111117 vectorize_with_builder_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type > (
112118 |lhs, rhs, output, ctx| {
113- let l =
114- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( lhs) } ;
115- let r =
116- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( rhs) } ;
117-
118- match l2_distance ( l. as_slice ( ) , r. as_slice ( ) ) {
119- Ok ( dist) => {
120- output. push ( F32 :: from ( dist) ) ;
121- }
122- Err ( err) => {
123- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
124- output. push ( F32 :: from ( 0.0 ) ) ;
125- }
119+ calculate_array_distance ( lhs, rhs, output, ctx, l2_distance) ;
120+ }
121+ ) ,
122+ ) ;
123+
124+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type , _ , _ > (
125+ "l2_distance" ,
126+ |_, _, _| FunctionDomain :: MayThrow ,
127+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type > (
128+ |lhs, rhs, output, ctx| {
129+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
130+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
131+ output. push ( F32 :: from ( 0.0 ) ) ;
132+ return ;
126133 }
134+ calculate_array_distance ( lhs. column , rhs. column , output, ctx, l2_distance) ;
127135 }
128136 ) ,
129137 ) ;
@@ -133,20 +141,22 @@ pub fn register(registry: &mut FunctionRegistry) {
133141 |_, _, _| FunctionDomain :: MayThrow ,
134142 vectorize_with_builder_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type > (
135143 |lhs, rhs, output, ctx| {
136- let l =
137- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( lhs) } ;
138- let r =
139- unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( rhs) } ;
140-
141- match inner_product ( l. as_slice ( ) , r. as_slice ( ) ) {
142- Ok ( dist) => {
143- output. push ( F32 :: from ( dist) ) ;
144- }
145- Err ( err) => {
146- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
147- output. push ( F32 :: from ( 0.0 ) ) ;
148- }
144+ calculate_array_distance ( lhs, rhs, output, ctx, inner_product) ;
145+ }
146+ ) ,
147+ ) ;
148+
149+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type , _ , _ > (
150+ "inner_product" ,
151+ |_, _, _| FunctionDomain :: MayThrow ,
152+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float32Type > > , ArrayType < NullableType < Float32Type > > , Float32Type > (
153+ |lhs, rhs, output, ctx| {
154+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
155+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
156+ output. push ( F32 :: from ( 0.0 ) ) ;
157+ return ;
149158 }
159+ calculate_array_distance ( lhs. column , rhs. column , output, ctx, inner_product) ;
150160 }
151161 ) ,
152162 ) ;
@@ -156,20 +166,22 @@ pub fn register(registry: &mut FunctionRegistry) {
156166 |_, _, _| FunctionDomain :: MayThrow ,
157167 vectorize_with_builder_2_arg :: < ArrayType < Float64Type > , ArrayType < Float64Type > , Float64Type > (
158168 |lhs, rhs, output, ctx| {
159- let l =
160- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( lhs) } ;
161- let r =
162- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( rhs) } ;
163-
164- match cosine_distance_64 ( l. as_slice ( ) , r. as_slice ( ) ) {
165- Ok ( dist) => {
166- output. push ( F64 :: from ( dist) ) ;
167- }
168- Err ( err) => {
169- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
170- output. push ( F64 :: from ( 0.0 ) ) ;
171- }
169+ calculate_array_distance_64 ( lhs, rhs, output, ctx, cosine_distance_64) ;
170+ }
171+ ) ,
172+ ) ;
173+
174+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type , _ , _ > (
175+ "cosine_distance" ,
176+ |_, _, _| FunctionDomain :: MayThrow ,
177+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type > (
178+ |lhs, rhs, output, ctx| {
179+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
180+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
181+ output. push ( F64 :: from ( 0.0 ) ) ;
182+ return ;
172183 }
184+ calculate_array_distance_64 ( lhs. column , rhs. column , output, ctx, cosine_distance_64) ;
173185 }
174186 ) ,
175187 ) ;
@@ -179,20 +191,22 @@ pub fn register(registry: &mut FunctionRegistry) {
179191 |_, _, _| FunctionDomain :: MayThrow ,
180192 vectorize_with_builder_2_arg :: < ArrayType < Float64Type > , ArrayType < Float64Type > , Float64Type > (
181193 |lhs, rhs, output, ctx| {
182- let l =
183- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( lhs) } ;
184- let r =
185- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( rhs) } ;
186-
187- match l1_distance_64 ( l. as_slice ( ) , r. as_slice ( ) ) {
188- Ok ( dist) => {
189- output. push ( F64 :: from ( dist) ) ;
190- }
191- Err ( err) => {
192- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
193- output. push ( F64 :: from ( 0.0 ) ) ;
194- }
194+ calculate_array_distance_64 ( lhs, rhs, output, ctx, l1_distance_64) ;
195+ }
196+ ) ,
197+ ) ;
198+
199+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type , _ , _ > (
200+ "l1_distance" ,
201+ |_, _, _| FunctionDomain :: MayThrow ,
202+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type > (
203+ |lhs, rhs, output, ctx| {
204+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
205+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
206+ output. push ( F64 :: from ( 0.0 ) ) ;
207+ return ;
195208 }
209+ calculate_array_distance_64 ( lhs. column , rhs. column , output, ctx, l1_distance_64) ;
196210 }
197211 ) ,
198212 ) ;
@@ -202,20 +216,22 @@ pub fn register(registry: &mut FunctionRegistry) {
202216 |_, _, _| FunctionDomain :: MayThrow ,
203217 vectorize_with_builder_2_arg :: < ArrayType < Float64Type > , ArrayType < Float64Type > , Float64Type > (
204218 |lhs, rhs, output, ctx| {
205- let l =
206- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( lhs) } ;
207- let r =
208- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( rhs) } ;
209-
210- match l2_distance_64 ( l. as_slice ( ) , r. as_slice ( ) ) {
211- Ok ( dist) => {
212- output. push ( F64 :: from ( dist) ) ;
213- }
214- Err ( err) => {
215- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
216- output. push ( F64 :: from ( 0.0 ) ) ;
217- }
219+ calculate_array_distance_64 ( lhs, rhs, output, ctx, l2_distance_64) ;
220+ }
221+ ) ,
222+ ) ;
223+
224+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type , _ , _ > (
225+ "l2_distance" ,
226+ |_, _, _| FunctionDomain :: MayThrow ,
227+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type > (
228+ |lhs, rhs, output, ctx| {
229+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
230+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
231+ output. push ( F64 :: from ( 0.0 ) ) ;
232+ return ;
218233 }
234+ calculate_array_distance_64 ( lhs. column , rhs. column , output, ctx, l2_distance_64) ;
219235 }
220236 ) ,
221237 ) ;
@@ -225,20 +241,22 @@ pub fn register(registry: &mut FunctionRegistry) {
225241 |_, _, _| FunctionDomain :: MayThrow ,
226242 vectorize_with_builder_2_arg :: < ArrayType < Float64Type > , ArrayType < Float64Type > , Float64Type > (
227243 |lhs, rhs, output, ctx| {
228- let l =
229- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( lhs) } ;
230- let r =
231- unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( rhs) } ;
232-
233- match inner_product_64 ( l. as_slice ( ) , r. as_slice ( ) ) {
234- Ok ( dist) => {
235- output. push ( F64 :: from ( dist) ) ;
236- }
237- Err ( err) => {
238- ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
239- output. push ( F64 :: from ( 0.0 ) ) ;
240- }
244+ calculate_array_distance_64 ( lhs, rhs, output, ctx, inner_product_64) ;
245+ }
246+ ) ,
247+ ) ;
248+
249+ registry. register_passthrough_nullable_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type , _ , _ > (
250+ "inner_product" ,
251+ |_, _, _| FunctionDomain :: MayThrow ,
252+ vectorize_with_builder_2_arg :: < ArrayType < NullableType < Float64Type > > , ArrayType < NullableType < Float64Type > > , Float64Type > (
253+ |lhs, rhs, output, ctx| {
254+ if lhs. validity . null_count ( ) > 0 || rhs. validity . null_count ( ) > 0 {
255+ ctx. set_error ( output. len ( ) , "Vector contain null values" ) ;
256+ output. push ( F64 :: from ( 0.0 ) ) ;
257+ return ;
241258 }
259+ calculate_array_distance_64 ( lhs. column , rhs. column , output, ctx, inner_product_64) ;
242260 }
243261 ) ,
244262 ) ;
@@ -645,3 +663,49 @@ fn calculate_norm(value: &VectorScalarRef) -> f32 {
645663 }
646664 }
647665}
666+
667+ fn calculate_array_distance < F > (
668+ lhs : Buffer < F32 > ,
669+ rhs : Buffer < F32 > ,
670+ output : & mut Vec < F32 > ,
671+ ctx : & mut EvalContext ,
672+ distance_fn : F ,
673+ ) where
674+ F : Fn ( & [ f32 ] , & [ f32 ] ) -> Result < f32 > ,
675+ {
676+ let l = unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( lhs) } ;
677+ let r = unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( rhs) } ;
678+
679+ match distance_fn ( l. as_slice ( ) , r. as_slice ( ) ) {
680+ Ok ( dist) => {
681+ output. push ( F32 :: from ( dist) ) ;
682+ }
683+ Err ( err) => {
684+ ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
685+ output. push ( F32 :: from ( 0.0 ) ) ;
686+ }
687+ }
688+ }
689+
690+ fn calculate_array_distance_64 < F > (
691+ lhs : Buffer < F64 > ,
692+ rhs : Buffer < F64 > ,
693+ output : & mut Vec < F64 > ,
694+ ctx : & mut EvalContext ,
695+ distance_fn : F ,
696+ ) where
697+ F : Fn ( & [ f64 ] , & [ f64 ] ) -> Result < f64 > ,
698+ {
699+ let l = unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( lhs) } ;
700+ let r = unsafe { std:: mem:: transmute :: < Buffer < F64 > , Buffer < f64 > > ( rhs) } ;
701+
702+ match distance_fn ( l. as_slice ( ) , r. as_slice ( ) ) {
703+ Ok ( dist) => {
704+ output. push ( F64 :: from ( dist) ) ;
705+ }
706+ Err ( err) => {
707+ ctx. set_error ( output. len ( ) , err. to_string ( ) ) ;
708+ output. push ( F64 :: from ( 0.0 ) ) ;
709+ }
710+ }
711+ }
0 commit comments