Skip to content

Commit 1451f56

Browse files
committed
HHH-18973 Cleanup vector module and add MySQL vector support
Also add support for optional cast patterns to JdbcType to avoid having to touch Dialect for new JdbcType and DdlType.
1 parent 5ded8ea commit 1451f56

33 files changed

+971
-1035
lines changed

documentation/src/main/asciidoc/userguide/chapters/query/extensions/Vector.adoc

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,21 @@ The Hibernate ORM Vector module contains support for mathematical vector types a
1212
This is useful for AI/ML topics like vector similarity search and Retrieval-Augmented Generation (RAG).
1313
The module comes with support for a special `vector` data type that essentially represents an array of bytes, floats, or doubles.
1414

15-
So far, both the PostgreSQL extension `pgvector` and the Oracle database 23ai+ `AI Vector Search` feature are supported, but in theory,
16-
the vector specific functions could be implemented to work with every database that supports arrays.
15+
Currently, the following databases are supported:
1716

18-
For further details, refer to the https://github.com/pgvector/pgvector#querying[pgvector documentation] or the https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[AI Vector Search documentation].
17+
* PostgreSQL 13+ through the https://github.com/pgvector/pgvector#querying[`pgvector` extension]
18+
* https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[Oracle database 23ai+]
19+
* https://mariadb.com/docs/server/reference/sql-structure/vectors/vector-overview[MariaDB 11.7+]
20+
* https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html[MySQL 9.0+]
21+
22+
In theory, the vector-specific functions could be implemented to work with every database that supports arrays.
23+
24+
[WARNING]
25+
====
26+
Per the https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html#function_distance[MySQL documentation],
27+
the various vector distance functions for MySQL only work on MySQL cloud offerings like
28+
https://dev.mysql.com/doc/heatwave/en/mys-hw-about-heatwave.html[HeatWave MySQL on OCI].
29+
====
1930

2031
[[vector-module-setup]]
2132
=== Setup
@@ -57,7 +68,7 @@ As Oracle AI Vector Search supports different types of elements (to ensure bette
5768
====
5869
[source, java, indent=0]
5970
----
60-
include::{example-dir-vector}/PGVectorTest.java[tags=usage-example]
71+
include::{example-dir-vector}/FloatVectorTest.java[tags=usage-example]
6172
----
6273
====
6374

@@ -113,7 +124,7 @@ which is `1 - inner_product( v1, v2 ) / ( vector_norm( v1 ) * vector_norm( v2 )
113124
====
114125
[source, java, indent=0]
115126
----
116-
include::{example-dir-vector}/PGVectorTest.java[tags=cosine-distance-example]
127+
include::{example-dir-vector}/FloatVectorTest.java[tags=cosine-distance-example]
117128
----
118129
====
119130

@@ -128,7 +139,7 @@ The `l2_distance()` function is an alias.
128139
====
129140
[source, java, indent=0]
130141
----
131-
include::{example-dir-vector}/PGVectorTest.java[tags=euclidean-distance-example]
142+
include::{example-dir-vector}/FloatVectorTest.java[tags=euclidean-distance-example]
132143
----
133144
====
134145

@@ -143,7 +154,7 @@ The `l1_distance()` function is an alias.
143154
====
144155
[source, java, indent=0]
145156
----
146-
include::{example-dir-vector}/PGVectorTest.java[tags=taxicab-distance-example]
157+
include::{example-dir-vector}/FloatVectorTest.java[tags=taxicab-distance-example]
147158
----
148159
====
149160

@@ -158,7 +169,7 @@ and the `inner_product()` function as well, but multiplies the result time `-1`.
158169
====
159170
[source, java, indent=0]
160171
----
161-
include::{example-dir-vector}/PGVectorTest.java[tags=inner-product-example]
172+
include::{example-dir-vector}/FloatVectorTest.java[tags=inner-product-example]
162173
----
163174
====
164175

@@ -171,7 +182,7 @@ Determines the dimensions of a vector.
171182
====
172183
[source, java, indent=0]
173184
----
174-
include::{example-dir-vector}/PGVectorTest.java[tags=vector-dims-example]
185+
include::{example-dir-vector}/FloatVectorTest.java[tags=vector-dims-example]
175186
----
176187
====
177188

@@ -185,7 +196,7 @@ which is `sqrt( sum( v_i^2 ) )`.
185196
====
186197
[source, java, indent=0]
187198
----
188-
include::{example-dir-vector}/PGVectorTest.java[tags=vector-norm-example]
199+
include::{example-dir-vector}/FloatVectorTest.java[tags=vector-norm-example]
189200
----
190201
====
191202

hibernate-core/src/main/java/org/hibernate/dialect/function/CastFunction.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,14 @@ public void render(
7777
renderCastArrayToString( sqlAppender, arguments.get( 0 ), dialect, walker );
7878
}
7979
else {
80-
new PatternRenderer( dialect.castPattern( sourceType, targetType ) )
81-
.render( sqlAppender, arguments, walker );
80+
String castPattern = targetJdbcMapping.getJdbcType().castFromPattern( sourceMapping );
81+
if ( castPattern == null ) {
82+
castPattern = sourceMapping.getJdbcType().castToPattern( targetJdbcMapping );
83+
if ( castPattern == null ) {
84+
castPattern = dialect.castPattern( sourceType, targetType );
85+
}
86+
}
87+
new PatternRenderer( castPattern ).render( sqlAppender, arguments, walker );
8288
}
8389
}
8490

hibernate-core/src/main/java/org/hibernate/dialect/function/SumReturnTypeResolver.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ public ReturnableType<?> resolveFunctionReturnType(
9090
case NUMERIC:
9191
return BigInteger.class.isAssignableFrom( basicType.getJavaType() ) ? bigIntegerType : bigDecimalType;
9292
case VECTOR:
93+
case VECTOR_FLOAT32:
94+
case VECTOR_FLOAT64:
95+
case VECTOR_INT8:
9396
return basicType;
9497
}
9598
return bigDecimalType;
@@ -123,6 +126,9 @@ public BasicValuedMapping resolveFunctionReturnType(
123126
final Class<?> argTypeClass = jdbcMapping.getJavaTypeDescriptor().getJavaTypeClass();
124127
return BigInteger.class.isAssignableFrom( argTypeClass ) ? bigIntegerType : bigDecimalType;
125128
case VECTOR:
129+
case VECTOR_FLOAT32:
130+
case VECTOR_FLOAT64:
131+
case VECTOR_INT8:
126132
return (BasicValuedMapping) jdbcMapping;
127133
}
128134
return bigDecimalType;

hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
import java.sql.SQLException;
1010
import java.sql.Types;
1111

12+
import org.checkerframework.checker.nullness.qual.Nullable;
1213
import org.hibernate.Incubating;
1314
import org.hibernate.boot.model.relational.Database;
1415
import org.hibernate.dialect.Dialect;
1516
import org.hibernate.engine.jdbc.Size;
17+
import org.hibernate.metamodel.mapping.JdbcMapping;
1618
import org.hibernate.query.sqm.CastType;
1719
import org.hibernate.sql.ast.spi.SqlAppender;
1820
import org.hibernate.sql.ast.spi.StringBuilderSqlAppender;
@@ -367,6 +369,30 @@ default String getExtraCreateTableInfo(JavaType<?> javaType, String columnName,
367369
return "";
368370
}
369371

372+
/**
373+
* Returns the cast pattern from the given source type to this type, or {@code null} if not possible.
374+
*
375+
* @param sourceMapping The source type
376+
* @return The cast pattern or null
377+
* @since 7.1
378+
*/
379+
@Incubating
380+
default @Nullable String castFromPattern(JdbcMapping sourceMapping) {
381+
return null;
382+
}
383+
384+
/**
385+
* Returns the cast pattern from this type to the given target type, or {@code null} if not possible.
386+
*
387+
* @param targetJdbcMapping The target type
388+
* @return The cast pattern or null
389+
* @since 7.1
390+
*/
391+
@Incubating
392+
default @Nullable String castToPattern(JdbcMapping targetJdbcMapping) {
393+
return null;
394+
}
395+
370396
@Incubating
371397
default boolean isComparable() {
372398
final int code = getDefaultSqlTypeCode();

hibernate-testing/src/main/java/org/hibernate/testing/orm/junit/DialectFeatureChecks.java

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
import org.hibernate.boot.internal.MetadataBuilderImpl;
1414
import org.hibernate.boot.internal.NamedProcedureCallDefinitionImpl;
1515
import org.hibernate.boot.model.FunctionContributions;
16+
import org.hibernate.boot.model.FunctionContributor;
1617
import org.hibernate.boot.model.IdentifierGeneratorDefinition;
1718
import org.hibernate.boot.model.NamedEntityGraphDefinition;
1819
import org.hibernate.boot.model.TypeContributions;
20+
import org.hibernate.boot.model.TypeContributor;
1921
import org.hibernate.boot.model.TypeDefinition;
2022
import org.hibernate.boot.model.TypeDefinitionRegistry;
2123
import org.hibernate.boot.model.convert.spi.ConverterAutoApplyHandler;
@@ -97,6 +99,7 @@
9799
import org.hibernate.type.descriptor.java.StringJavaType;
98100
import org.hibernate.type.descriptor.jdbc.JdbcType;
99101
import org.hibernate.type.descriptor.jdbc.VarcharJdbcType;
102+
import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry;
100103
import org.hibernate.type.internal.BasicTypeImpl;
101104
import org.hibernate.type.spi.TypeConfiguration;
102105
import org.hibernate.usertype.CompositeUserType;
@@ -105,6 +108,7 @@
105108
import java.util.HashMap;
106109
import java.util.List;
107110
import java.util.Map;
111+
import java.util.ServiceLoader;
108112
import java.util.Set;
109113
import java.util.UUID;
110114
import java.util.function.Consumer;
@@ -1081,6 +1085,66 @@ public boolean apply(Dialect dialect) {
10811085
}
10821086
}
10831087

1088+
public static class SupportsVectorType implements DialectFeatureCheck {
1089+
public boolean apply(Dialect dialect) {
1090+
return definesDdlType( dialect, SqlTypes.VECTOR );
1091+
}
1092+
}
1093+
1094+
public static class SupportsDoubleVectorType implements DialectFeatureCheck {
1095+
public boolean apply(Dialect dialect) {
1096+
return definesDdlType( dialect, SqlTypes.VECTOR_FLOAT64 );
1097+
}
1098+
}
1099+
1100+
public static class SupportsByteVectorType implements DialectFeatureCheck {
1101+
public boolean apply(Dialect dialect) {
1102+
return definesDdlType( dialect, SqlTypes.VECTOR_INT8 );
1103+
}
1104+
}
1105+
1106+
public static class SupportsCosineDistance implements DialectFeatureCheck {
1107+
public boolean apply(Dialect dialect) {
1108+
return definesFunction( dialect, "cosine_distance" );
1109+
}
1110+
}
1111+
1112+
public static class SupportsEuclideanDistance implements DialectFeatureCheck {
1113+
public boolean apply(Dialect dialect) {
1114+
return definesFunction( dialect, "euclidean_distance" );
1115+
}
1116+
}
1117+
1118+
public static class SupportsTaxicabDistance implements DialectFeatureCheck {
1119+
public boolean apply(Dialect dialect) {
1120+
return definesFunction( dialect, "taxicab_distance" );
1121+
}
1122+
}
1123+
1124+
public static class SupportsHammingDistance implements DialectFeatureCheck {
1125+
public boolean apply(Dialect dialect) {
1126+
return definesFunction( dialect, "hamming_distance" );
1127+
}
1128+
}
1129+
1130+
public static class SupportsInnerProduct implements DialectFeatureCheck {
1131+
public boolean apply(Dialect dialect) {
1132+
return definesFunction( dialect, "inner_product" );
1133+
}
1134+
}
1135+
1136+
public static class SupportsVectorDims implements DialectFeatureCheck {
1137+
public boolean apply(Dialect dialect) {
1138+
return definesFunction( dialect, "vector_dims" );
1139+
}
1140+
}
1141+
1142+
public static class SupportsVectorNorm implements DialectFeatureCheck {
1143+
public boolean apply(Dialect dialect) {
1144+
return definesFunction( dialect, "vector_norm" );
1145+
}
1146+
}
1147+
10841148
public static class IsJtds implements DialectFeatureCheck {
10851149
public boolean apply(Dialect dialect) {
10861150
return dialect instanceof SybaseDialect && ( (SybaseDialect) dialect ).getDriverKind() == SybaseDriverKind.JTDS;
@@ -1146,7 +1210,7 @@ public boolean apply(Dialect dialect) {
11461210
}
11471211
}
11481212

1149-
private static final HashMap<Dialect, SqmFunctionRegistry> FUNCTION_REGISTRIES = new HashMap<>();
1213+
private static final HashMap<Dialect, FakeFunctionContributions> FUNCTION_CONTRIBUTIONS = new HashMap<>();
11501214

11511215
public static boolean definesFunction(Dialect dialect, String functionName) {
11521216
return getSqmFunctionRegistry( dialect ).findFunctionDescriptor( functionName ) != null;
@@ -1156,6 +1220,11 @@ public static boolean definesSetReturningFunction(Dialect dialect, String functi
11561220
return getSqmFunctionRegistry( dialect ).findSetReturningFunctionDescriptor( functionName ) != null;
11571221
}
11581222

1223+
public static boolean definesDdlType(Dialect dialect, int typeCode) {
1224+
final DdlTypeRegistry ddlTypeRegistry = getFunctionContributions( dialect ).typeConfiguration.getDdlTypeRegistry();
1225+
return ddlTypeRegistry.getDescriptor( typeCode ) != null;
1226+
}
1227+
11591228
public static class SupportsSubqueryInSelect implements DialectFeatureCheck {
11601229
@Override
11611230
public boolean apply(Dialect dialect) {
@@ -1177,24 +1246,33 @@ public boolean apply(Dialect dialect) {
11771246
}
11781247
}
11791248

1180-
11811249
private static SqmFunctionRegistry getSqmFunctionRegistry(Dialect dialect) {
1182-
SqmFunctionRegistry sqmFunctionRegistry = FUNCTION_REGISTRIES.get( dialect );
1183-
if ( sqmFunctionRegistry == null ) {
1250+
return getFunctionContributions( dialect ).functionRegistry;
1251+
}
1252+
1253+
private static FakeFunctionContributions getFunctionContributions(Dialect dialect) {
1254+
FakeFunctionContributions functionContributions = FUNCTION_CONTRIBUTIONS.get( dialect );
1255+
if ( functionContributions == null ) {
11841256
final TypeConfiguration typeConfiguration = new TypeConfiguration();
11851257
final SqmFunctionRegistry functionRegistry = new SqmFunctionRegistry();
11861258
typeConfiguration.scope( new FakeMetadataBuildingContext( typeConfiguration, functionRegistry ) );
11871259
final FakeTypeContributions typeContributions = new FakeTypeContributions( typeConfiguration );
1188-
final FakeFunctionContributions functionContributions = new FakeFunctionContributions(
1260+
functionContributions = new FakeFunctionContributions(
11891261
dialect,
11901262
typeConfiguration,
11911263
functionRegistry
11921264
);
11931265
dialect.contribute( typeContributions, typeConfiguration.getServiceRegistry() );
11941266
dialect.initializeFunctionRegistry( functionContributions );
1195-
FUNCTION_REGISTRIES.put( dialect, sqmFunctionRegistry = functionContributions.functionRegistry );
1267+
for ( TypeContributor typeContributor : ServiceLoader.load( TypeContributor.class ) ) {
1268+
typeContributor.contribute( typeContributions, typeConfiguration.getServiceRegistry() );
1269+
}
1270+
for ( FunctionContributor functionContributor : ServiceLoader.load( FunctionContributor.class ) ) {
1271+
functionContributor.contributeFunctions( functionContributions );
1272+
}
1273+
FUNCTION_CONTRIBUTIONS.put( dialect, functionContributions );
11961274
}
1197-
return sqmFunctionRegistry;
1275+
return functionContributions;
11981276
}
11991277

12001278
public static class FakeTypeContributions implements TypeContributions {

hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import java.sql.ResultSet;
1010
import java.sql.SQLException;
1111

12+
import org.checkerframework.checker.nullness.qual.Nullable;
1213
import org.hibernate.dialect.Dialect;
14+
import org.hibernate.metamodel.mapping.JdbcMapping;
1315
import org.hibernate.sql.ast.spi.SqlAppender;
1416
import org.hibernate.type.SqlTypes;
1517
import org.hibernate.type.descriptor.ValueBinder;
@@ -43,13 +45,13 @@ public AbstractOracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSu
4345
this.isVectorSupported = isVectorSupported;
4446
}
4547

46-
public abstract void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect);
47-
4848
@Override
49-
public int getDefaultSqlTypeCode() {
50-
return SqlTypes.VECTOR;
49+
public @Nullable String castToPattern(JdbcMapping targetJdbcMapping) {
50+
return targetJdbcMapping.getJdbcType().isStringLike() ? "from_vector(?1 returning ?2)" : null;
5151
}
5252

53+
public abstract void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect);
54+
5355
@Override
5456
public <T> JdbcLiteralFormatter<T> getJdbcLiteralFormatter(JavaType<T> javaTypeDescriptor) {
5557
final JavaType<T> elementJavaType;

hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public class MariaDBFunctionContributor implements FunctionContributor {
1313
@Override
1414
public void contributeFunctions(FunctionContributions functionContributions) {
1515
final Dialect dialect = functionContributions.getDialect();
16-
if ( dialect instanceof MariaDBDialect ) {
16+
if ( dialect instanceof MariaDBDialect && dialect.getVersion().isSameOrAfter( 11, 7 ) ) {
1717
final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions );
1818

1919
vectorFunctionFactory.cosineDistance( "vec_distance_cosine(?1,?2)" );

0 commit comments

Comments
 (0)