@@ -20,84 +20,62 @@ public class OracleVectorFunctionContributor implements FunctionContributor {
2020
2121 @ Override
2222 public void contributeFunctions (FunctionContributions functionContributions ) {
23- final SqmFunctionRegistry functionRegistry = functionContributions .getFunctionRegistry ();
24- final TypeConfiguration typeConfiguration = functionContributions .getTypeConfiguration ();
25- final BasicTypeRegistry basicTypeRegistry = typeConfiguration .getBasicTypeRegistry ();
2623 final Dialect dialect = functionContributions .getDialect ();
27- if ( dialect instanceof OracleDialect ) {
28- final BasicType <Double > doubleType = basicTypeRegistry .resolve ( StandardBasicTypes .DOUBLE );
29- final BasicType <Integer > integerType = basicTypeRegistry .resolve ( StandardBasicTypes .INTEGER );
30- functionRegistry .patternDescriptorBuilder ( "cosine_distance" , "vector_distance(?1, ?2, COSINE)" )
31- .setArgumentsValidator ( StandardArgumentsValidators .composite (
32- StandardArgumentsValidators .exactly ( 2 ),
33- VectorArgumentValidator .INSTANCE
34- ) )
35- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
36- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
37- .register ();
38- functionRegistry .patternDescriptorBuilder ( "euclidean_distance" , "vector_distance(?1, ?2, EUCLIDEAN)" )
39- .setArgumentsValidator ( StandardArgumentsValidators .composite (
40- StandardArgumentsValidators .exactly ( 2 ),
41- VectorArgumentValidator .INSTANCE
42- ) )
43- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
44- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
45- .register ();
46- functionRegistry .registerAlternateKey ( "l2_distance" , "euclidean_distance" );
24+ if (dialect instanceof OracleDialect ) {
25+ final SqmFunctionRegistry functionRegistry = functionContributions .getFunctionRegistry ();
26+ final TypeConfiguration typeConfiguration = functionContributions .getTypeConfiguration ();
27+ final BasicTypeRegistry basicTypeRegistry = typeConfiguration .getBasicTypeRegistry ();
28+ final BasicType <Double > doubleType = basicTypeRegistry .resolve (StandardBasicTypes .DOUBLE );
29+ final BasicType <Integer > integerType = basicTypeRegistry .resolve (StandardBasicTypes .INTEGER );
4730
48- functionRegistry .patternDescriptorBuilder ( "l1_distance" , "vector_distance(?1, ?2, MANHATTAN)" )
49- .setArgumentsValidator ( StandardArgumentsValidators .composite (
50- StandardArgumentsValidators .exactly ( 2 ),
51- VectorArgumentValidator .INSTANCE
52- ) )
53- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
54- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
55- .register ();
56- functionRegistry .registerAlternateKey ( "taxicab_distance" , "l1_distance" );
31+ registerVectorDistanceFunction (functionRegistry , "cosine_distance" , "vector_distance(?1, ?2, COSINE)" , doubleType );
32+ registerVectorDistanceFunction (functionRegistry , "euclidean_distance" , "vector_distance(?1, ?2, EUCLIDEAN)" , doubleType );
33+ functionRegistry .registerAlternateKey ("l2_distance" , "euclidean_distance" );
5734
58- functionRegistry .patternDescriptorBuilder ( "negative_inner_product" , "vector_distance(?1, ?2, DOT)" )
59- .setArgumentsValidator ( StandardArgumentsValidators .composite (
60- StandardArgumentsValidators .exactly ( 2 ),
61- VectorArgumentValidator .INSTANCE
62- ) )
63- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
64- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
65- .register ();
66- functionRegistry .patternDescriptorBuilder ( "inner_product" , "vector_distance(?1, ?2, DOT)*-1" )
67- .setArgumentsValidator ( StandardArgumentsValidators .composite (
68- StandardArgumentsValidators .exactly ( 2 ),
69- VectorArgumentValidator .INSTANCE
70- ) )
71- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
72- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
73- .register ();
74- functionRegistry .patternDescriptorBuilder ( "hamming_distance" , "vector_distance(?1, ?2, HAMMING)" )
75- .setArgumentsValidator ( StandardArgumentsValidators .composite (
76- StandardArgumentsValidators .exactly ( 2 ),
77- VectorArgumentValidator .INSTANCE
78- ) )
79- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
80- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
81- .register ();
82- functionRegistry .namedDescriptorBuilder ( "vector_dims" )
83- .setArgumentsValidator ( StandardArgumentsValidators .composite (
84- StandardArgumentsValidators .exactly ( 1 ),
85- VectorArgumentValidator .INSTANCE
86- ) )
87- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
88- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( integerType ) )
89- .register ();
90- functionRegistry .namedDescriptorBuilder ( "vector_norm" )
91- .setArgumentsValidator ( StandardArgumentsValidators .composite (
92- StandardArgumentsValidators .exactly ( 1 ),
93- VectorArgumentValidator .INSTANCE
94- ) )
95- .setArgumentTypeResolver ( VectorArgumentTypeResolver .INSTANCE )
96- .setReturnTypeResolver ( StandardFunctionReturnTypeResolvers .invariant ( doubleType ) )
97- .register ();
35+ registerVectorDistanceFunction (functionRegistry , "l1_distance" , "vector_distance(?1, ?2, MANHATTAN)" , doubleType );
36+ functionRegistry .registerAlternateKey ("taxicab_distance" , "l1_distance" );
37+
38+ registerVectorDistanceFunction (functionRegistry , "negative_inner_product" , "vector_distance(?1, ?2, DOT)" , doubleType );
39+ registerVectorDistanceFunction (functionRegistry , "inner_product" , "vector_distance(?1, ?2, DOT)*-1" , doubleType );
40+ registerVectorDistanceFunction (functionRegistry , "hamming_distance" , "vector_distance(?1, ?2, HAMMING)" , doubleType );
41+
42+ registerNamedVectorFunction (functionRegistry , "vector_dims" , integerType , 1 );
43+ registerNamedVectorFunction (functionRegistry , "vector_norm" , doubleType , 1 );
9844 }
9945 }
10046
47+ private void registerVectorDistanceFunction (
48+ SqmFunctionRegistry functionRegistry ,
49+ String functionName ,
50+ String pattern ,
51+ BasicType <?> returnType ) {
52+
53+ functionRegistry .patternDescriptorBuilder (functionName , pattern )
54+ .setArgumentsValidator (StandardArgumentsValidators .composite (
55+ StandardArgumentsValidators .exactly (2 ),
56+ VectorArgumentValidator .INSTANCE
57+ ))
58+ .setArgumentTypeResolver (VectorArgumentTypeResolver .INSTANCE )
59+ .setReturnTypeResolver (StandardFunctionReturnTypeResolvers .invariant (returnType ))
60+ .register ();
61+ }
62+
63+ private void registerNamedVectorFunction (
64+ SqmFunctionRegistry functionRegistry ,
65+ String functionName ,
66+ BasicType <?> returnType ,
67+ int argumentCount ) {
68+
69+ functionRegistry .namedDescriptorBuilder (functionName )
70+ .setArgumentsValidator (StandardArgumentsValidators .composite (
71+ StandardArgumentsValidators .exactly (argumentCount ),
72+ VectorArgumentValidator .INSTANCE
73+ ))
74+ .setArgumentTypeResolver (VectorArgumentTypeResolver .INSTANCE )
75+ .setReturnTypeResolver (StandardFunctionReturnTypeResolvers .invariant (returnType ))
76+ .register ();
77+ }
78+
10179 @ Override
10280 public int ordinal () {
10381 return 200 ;
0 commit comments