|
8 | 8 | import org.hibernate.boot.model.FunctionContributor; |
9 | 9 | import org.hibernate.dialect.Dialect; |
10 | 10 | import org.hibernate.dialect.OracleDialect; |
11 | | -import org.hibernate.query.sqm.function.SqmFunctionRegistry; |
12 | | -import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; |
13 | | -import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; |
14 | | -import org.hibernate.type.BasicType; |
15 | | -import org.hibernate.type.BasicTypeRegistry; |
16 | | -import org.hibernate.type.StandardBasicTypes; |
17 | | -import org.hibernate.type.spi.TypeConfiguration; |
18 | 11 |
|
19 | 12 | public class OracleVectorFunctionContributor implements FunctionContributor { |
20 | 13 |
|
21 | 14 | @Override |
22 | 15 | public void contributeFunctions(FunctionContributions functionContributions) { |
23 | 16 | final Dialect dialect = functionContributions.getDialect(); |
24 | | - if (dialect instanceof OracleDialect) { |
25 | | - final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry(); |
26 | | - final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration(); |
27 | | - final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); |
28 | | - final BasicType<Double> doubleType = basicTypeRegistry.resolve(StandardBasicTypes.DOUBLE); |
29 | | - final BasicType<Integer> integerType = basicTypeRegistry.resolve(StandardBasicTypes.INTEGER); |
| 17 | + if ( dialect instanceof OracleDialect ) { |
| 18 | + final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions ); |
30 | 19 |
|
31 | | - registerVectorDistanceFunction(functionRegistry, "cosine_distance", "vector_distance(?1, ?2, COSINE)", doubleType); |
32 | | - registerVectorDistanceFunction(functionRegistry, "euclidean_distance", "vector_distance(?1, ?2, EUCLIDEAN)", doubleType); |
33 | | - functionRegistry.registerAlternateKey("l2_distance", "euclidean_distance"); |
| 20 | + vectorFunctionFactory.cosineDistance( "vector_distance(?1,?2,COSINE)" ); |
| 21 | + vectorFunctionFactory.euclideanDistance( "vector_distance(?1,?2,EUCLIDEAN)" ); |
| 22 | + vectorFunctionFactory.l1Distance( "vector_distance(?1,?2,MANHATTAN)" ); |
| 23 | + vectorFunctionFactory.hammingDistance( "vector_distance(?1,?2,HAMMING)" ); |
34 | 24 |
|
35 | | - registerVectorDistanceFunction(functionRegistry, "l1_distance", "vector_distance(?1, ?2, MANHATTAN)", doubleType); |
36 | | - functionRegistry.registerAlternateKey("taxicab_distance", "l1_distance"); |
| 25 | + vectorFunctionFactory.innerProduct( "vector_distance(?1,?2,DOT)*-1" ); |
| 26 | + vectorFunctionFactory.negativeInnerProduct( "vector_distance(?1,?2,DOT)" ); |
37 | 27 |
|
38 | | - registerVectorDistanceFunction(functionRegistry, "negative_inner_product", "vector_distance(?1, ?2, DOT)", doubleType); |
39 | | - registerVectorDistanceFunction(functionRegistry, "inner_product", "vector_distance(?1, ?2, DOT)*-1", doubleType); |
40 | | - registerVectorDistanceFunction(functionRegistry, "hamming_distance", "vector_distance(?1, ?2, HAMMING)", doubleType); |
41 | | - |
42 | | - registerNamedVectorFunction(functionRegistry, "vector_dims", integerType, 1); |
43 | | - registerNamedVectorFunction(functionRegistry, "vector_norm", doubleType, 1); |
| 28 | + vectorFunctionFactory.vectorDimensions(); |
| 29 | + vectorFunctionFactory.vectorNorm(); |
44 | 30 | } |
45 | 31 | } |
46 | 32 |
|
47 | | - private void registerVectorDistanceFunction( |
48 | | - SqmFunctionRegistry functionRegistry, |
49 | | - String functionName, |
50 | | - String pattern, |
51 | | - BasicType<?> returnType) { |
52 | | - |
53 | | - functionRegistry.patternDescriptorBuilder(functionName, pattern) |
54 | | - .setArgumentsValidator(StandardArgumentsValidators.composite( |
55 | | - StandardArgumentsValidators.exactly(2), |
56 | | - VectorArgumentValidator.INSTANCE |
57 | | - )) |
58 | | - .setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE) |
59 | | - .setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType)) |
60 | | - .register(); |
61 | | - } |
62 | | - |
63 | | - private void registerNamedVectorFunction( |
64 | | - SqmFunctionRegistry functionRegistry, |
65 | | - String functionName, |
66 | | - BasicType<?> returnType, |
67 | | - int argumentCount) { |
68 | | - |
69 | | - functionRegistry.namedDescriptorBuilder(functionName) |
70 | | - .setArgumentsValidator(StandardArgumentsValidators.composite( |
71 | | - StandardArgumentsValidators.exactly(argumentCount), |
72 | | - VectorArgumentValidator.INSTANCE |
73 | | - )) |
74 | | - .setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE) |
75 | | - .setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType)) |
76 | | - .register(); |
77 | | - } |
78 | | - |
79 | 33 | @Override |
80 | 34 | public int ordinal() { |
81 | 35 | return 200; |
|
0 commit comments