Skip to content

Commit c337381

Browse files
committed
HHH-18900 MariaDB Vector support
1 parent 297db57 commit c337381

File tree

12 files changed

+422
-11
lines changed

12 files changed

+422
-11
lines changed

databases/mariadb/matrix.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
55
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
66
*/
7-
jdbcDependency 'org.mariadb.jdbc:mariadb-java-client:3.4.0'
7+
jdbcDependency 'org.mariadb.jdbc:mariadb-java-client:3.5.1'

docker_db.sh

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ mysql_8_2() {
9292
}
9393

9494
mariadb() {
95-
mariadb_11_4
95+
mariadb_11_7
9696
}
9797

9898
mariadb_wait_until_start()
@@ -138,6 +138,12 @@ mariadb_11_4() {
138138
mariadb_wait_until_start
139139
}
140140

141+
mariadb_11_7() {
142+
$CONTAINER_CLI rm -f mariadb || true
143+
$CONTAINER_CLI run --name mariadb -e MARIADB_USER=hibernate_orm_test -e MARIADB_PASSWORD=hibernate_orm_test -e MARIADB_DATABASE=hibernate_orm_test -e MARIADB_ROOT_PASSWORD=hibernate_orm_test -p3306:3306 -d ${DB_IMAGE_MARIADB_11_7:-docker.io/mariadb:11.7-rc} --character-set-server=utf8mb4 --collation-server=utf8mb4_bin --skip-character-set-client-handshake --lower_case_table_names=2
144+
mariadb_wait_until_start
145+
}
146+
141147
mariadb_verylatest() {
142148
$CONTAINER_CLI rm -f mariadb || true
143149
$CONTAINER_CLI run --name mariadb -e MARIADB_USER=hibernate_orm_test -e MARIADB_PASSWORD=hibernate_orm_test -e MARIADB_DATABASE=hibernate_orm_test -e MARIADB_ROOT_PASSWORD=hibernate_orm_test -p3306:3306 -d ${DB_IMAGE_MARIADB_VERYLATEST:-quay.io/mariadb-foundation/mariadb-devel:verylatest} --character-set-server=utf8mb4 --collation-server=utf8mb4_bin --skip-character-set-client-handshake --lower_case_table_names=2
@@ -996,6 +1002,7 @@ if [ -z ${1} ]; then
9961002
echo -e "\thana"
9971003
echo -e "\tmariadb"
9981004
echo -e "\tmariadb_verylatest"
1005+
echo -e "\tmariadb_11_7"
9991006
echo -e "\tmariadb_11_4"
10001007
echo -e "\tmariadb_11_1"
10011008
echo -e "\tmariadb_10_11"

hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -682,8 +682,8 @@ public class SqlTypes {
682682

683683
/**
684684
* A type code representing an {@code embedding vector} type for databases
685-
* like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} and
686-
* {@link org.hibernate.dialect.OracleDialect Oracle 23ai}.
685+
* like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL},
686+
* {@link org.hibernate.dialect.OracleDialect Oracle 23ai} and {@link org.hibernate.dialect.MariaDBDialect MariaDB}.
687687
* An embedding vector essentially is a {@code float[]} with a fixed size.
688688
*
689689
* @since 6.4

hibernate-core/src/test/java/org/hibernate/orm/test/temporal/MySQLTimestampFspFunctionTest.java

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
*/
55
package org.hibernate.orm.test.temporal;
66

7+
import java.sql.Time;
78
import java.sql.Timestamp;
89

910
import org.hibernate.dialect.MySQLDialect;
@@ -33,7 +34,7 @@ public class MySQLTimestampFspFunctionTest {
3334

3435
@Test
3536
public void testTimeStampFunctions(SessionFactoryScope scope) {
36-
// current_timestamp(), localtime(), and localtimestamp() are synonyms for now(),
37+
// current_timestamp(), localtimestamp() are synonyms for now(),
3738
// which returns the time at which the statement began to execute.
3839
// the returned values for now(), current_timestamp(), localtime(), and
3940
// localtimestamp() should be the same.
@@ -42,7 +43,7 @@ public void testTimeStampFunctions(SessionFactoryScope scope) {
4243
scope.inSession(
4344
s -> {
4445
Query q = s.createQuery(
45-
"select now(), current_timestamp(), localtime(), localtimestamp(), sysdate()"
46+
"select now(), current_timestamp(), localtimestamp(), sysdate()"
4647
);
4748
Object[] oArray = (Object[]) q.uniqueResult();
4849
for ( Object o : oArray ) {
@@ -51,8 +52,27 @@ public void testTimeStampFunctions(SessionFactoryScope scope) {
5152
final Timestamp now = (Timestamp) oArray[0];
5253
assertEquals( now, oArray[1] );
5354
assertEquals( now, oArray[2] );
55+
assertTrue( now.compareTo( (Timestamp) oArray[3] ) <= 0 );
56+
}
57+
);
58+
}
59+
@Test
60+
public void testTimeFunctions(SessionFactoryScope scope) {
61+
// the returned TIME values for now(), current_timestamp(), localtime(), and
62+
// localtimestamp() should be the same.
63+
// sysdate() is the time at which the function itself is executed, so the
64+
// value returned for sysdate() should be different.
65+
scope.inSession(
66+
s -> {
67+
Query q = s.createQuery(
68+
"select CAST(now() AS TIME), CAST(current_timestamp() AS TIME), CAST(localtime() AS TIME), CAST(localtimestamp() AS TIME), CAST(sysdate() AS TIME)"
69+
);
70+
Object[] oArray = (Object[]) q.uniqueResult();
71+
final Time now = (Time) oArray[0];
72+
assertEquals( now, oArray[1] );
73+
assertEquals( now, oArray[2] );
5474
assertEquals( now, oArray[3] );
55-
assertTrue( now.compareTo( (Timestamp) oArray[4] ) <= 0 );
75+
assertTrue( now.compareTo( (Time) oArray[4] ) <= 0 );
5676
}
5777
);
5878
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* SPDX-License-Identifier: LGPL-2.1-or-later
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector;
6+
7+
import org.hibernate.dialect.Dialect;
8+
import org.hibernate.sql.ast.spi.SqlAppender;
9+
import org.hibernate.type.SqlTypes;
10+
import org.hibernate.type.descriptor.ValueBinder;
11+
import org.hibernate.type.descriptor.ValueExtractor;
12+
import org.hibernate.type.descriptor.WrapperOptions;
13+
import org.hibernate.type.descriptor.java.JavaType;
14+
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
15+
import org.hibernate.type.descriptor.jdbc.BasicBinder;
16+
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
17+
import org.hibernate.type.descriptor.jdbc.JdbcType;
18+
import org.hibernate.type.spi.TypeConfiguration;
19+
20+
import java.sql.CallableStatement;
21+
import java.sql.PreparedStatement;
22+
import java.sql.ResultSet;
23+
import java.sql.SQLException;
24+
25+
public class BinaryVectorJdbcType extends ArrayJdbcType {
26+
27+
public BinaryVectorJdbcType(JdbcType elementJdbcType) {
28+
super( elementJdbcType );
29+
}
30+
31+
@Override
32+
public int getDefaultSqlTypeCode() {
33+
return SqlTypes.VECTOR;
34+
}
35+
36+
@Override
37+
public <T> JavaType<T> getJdbcRecommendedJavaTypeMapping(
38+
Integer precision,
39+
Integer scale,
40+
TypeConfiguration typeConfiguration) {
41+
return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class );
42+
}
43+
44+
@Override
45+
public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
46+
appender.append( writeExpression );
47+
}
48+
49+
@Override
50+
public <X> ValueExtractor<X> getExtractor(JavaType<X> javaTypeDescriptor) {
51+
return new BasicExtractor<>( javaTypeDescriptor, this ) {
52+
@Override
53+
protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException {
54+
return javaTypeDescriptor.wrap( rs.getObject( paramIndex, float[].class ), options );
55+
}
56+
57+
@Override
58+
protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException {
59+
return javaTypeDescriptor.wrap( statement.getObject( index, float[].class ), options );
60+
}
61+
62+
@Override
63+
protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException {
64+
return javaTypeDescriptor.wrap( statement.getObject( name, float[].class ), options );
65+
}
66+
67+
};
68+
}
69+
70+
@Override
71+
public <X> ValueBinder<X> getBinder(final JavaType<X> javaTypeDescriptor) {
72+
return new BasicBinder<>( javaTypeDescriptor, this ) {
73+
74+
@Override
75+
protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException {
76+
st.setObject( index, value );
77+
}
78+
79+
@Override
80+
protected void doBind(CallableStatement st, X value, String name, WrapperOptions options)
81+
throws SQLException {
82+
st.setObject( name, value, java.sql.Types.ARRAY );
83+
}
84+
85+
@Override
86+
public Object getBindValue(X value, WrapperOptions options) {
87+
return value;
88+
}
89+
};
90+
}
91+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* SPDX-License-Identifier: LGPL-2.1-or-later
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector;
6+
7+
import org.hibernate.boot.model.FunctionContributions;
8+
import org.hibernate.boot.model.FunctionContributor;
9+
import org.hibernate.dialect.Dialect;
10+
import org.hibernate.dialect.MariaDBDialect;
11+
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
12+
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
13+
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
14+
import org.hibernate.type.BasicType;
15+
import org.hibernate.type.BasicTypeRegistry;
16+
import org.hibernate.type.StandardBasicTypes;
17+
import org.hibernate.type.spi.TypeConfiguration;
18+
19+
public class MariaDBFunctionContributor implements FunctionContributor {
20+
21+
@Override
22+
public void contributeFunctions(FunctionContributions functionContributions) {
23+
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
24+
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
25+
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
26+
final Dialect dialect = functionContributions.getDialect();
27+
if ( dialect instanceof MariaDBDialect ) {
28+
final BasicType<Double> doubleType = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE );
29+
30+
functionRegistry.patternDescriptorBuilder( "cosine_distance", "vec_distance_cosine(?1,?2)" )
31+
.setArgumentsValidator( StandardArgumentsValidators.composite(
32+
StandardArgumentsValidators.exactly( 2 ),
33+
VectorArgumentValidator.INSTANCE
34+
) )
35+
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
36+
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
37+
.register();
38+
functionRegistry.patternDescriptorBuilder( "euclidean_distance", "vec_distance_euclidean(?1,?2)" )
39+
.setArgumentsValidator( StandardArgumentsValidators.composite(
40+
StandardArgumentsValidators.exactly( 2 ),
41+
VectorArgumentValidator.INSTANCE
42+
) )
43+
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
44+
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
45+
.register();
46+
functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" );
47+
48+
}
49+
}
50+
51+
@Override
52+
public int ordinal() {
53+
return 200;
54+
}
55+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* SPDX-License-Identifier: LGPL-2.1-or-later
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector;
6+
7+
import org.hibernate.boot.model.TypeContributions;
8+
import org.hibernate.boot.model.TypeContributor;
9+
import org.hibernate.dialect.Dialect;
10+
import org.hibernate.dialect.MariaDBDialect;
11+
import org.hibernate.engine.jdbc.Size;
12+
import org.hibernate.engine.jdbc.spi.JdbcServices;
13+
import org.hibernate.service.ServiceRegistry;
14+
import org.hibernate.type.BasicArrayType;
15+
import org.hibernate.type.BasicType;
16+
import org.hibernate.type.BasicTypeRegistry;
17+
import org.hibernate.type.SqlTypes;
18+
import org.hibernate.type.StandardBasicTypes;
19+
import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry;
20+
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
21+
import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry;
22+
import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl;
23+
import org.hibernate.type.spi.TypeConfiguration;
24+
25+
import java.lang.reflect.Type;
26+
27+
public class MariaDBTypeContributor implements TypeContributor {
28+
29+
private static final Type[] VECTOR_JAVA_TYPES = {
30+
Float[].class,
31+
float[].class
32+
};
33+
34+
@Override
35+
public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) {
36+
final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect();
37+
if ( dialect instanceof MariaDBDialect ) {
38+
final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration();
39+
final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry();
40+
final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry();
41+
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
42+
final BasicType<Float> floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT );
43+
final ArrayJdbcType vectorJdbcType = new BinaryVectorJdbcType( jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ) );
44+
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, vectorJdbcType );
45+
for ( Type vectorJavaType : VECTOR_JAVA_TYPES ) {
46+
basicTypeRegistry.register(
47+
new BasicArrayType<>(
48+
floatBasicType,
49+
vectorJdbcType,
50+
javaTypeRegistry.getDescriptor( vectorJavaType )
51+
),
52+
StandardBasicTypes.VECTOR.getName()
53+
);
54+
}
55+
typeConfiguration.getDdlTypeRegistry().addDescriptor(
56+
new DdlTypeImpl( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) {
57+
@Override
58+
public String getTypeName(Size size) {
59+
return getTypeName(
60+
size.getArrayLength() == null ? null : size.getArrayLength().longValue(),
61+
null,
62+
null
63+
);
64+
}
65+
}
66+
);
67+
}
68+
}
69+
}
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
org.hibernate.vector.PGVectorFunctionContributor
2-
org.hibernate.vector.OracleVectorFunctionContributor
2+
org.hibernate.vector.OracleVectorFunctionContributor
3+
org.hibernate.vector.MariaDBFunctionContributor
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
org.hibernate.vector.PGVectorTypeContributor
2-
org.hibernate.vector.OracleVectorTypeContributor
2+
org.hibernate.vector.OracleVectorTypeContributor
3+
org.hibernate.vector.MariaDBTypeContributor

0 commit comments

Comments
 (0)