Skip to content

HHH-18973, HHH-19679 hibernate-vector module enhancements #10685

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions docker_db.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
#! /bin/bash

if command -v docker > /dev/null; then
CONTAINER_CLI=$(command -v docker)
HEALTCHECK_PATH="{{.State.Health.Status}}"
PRIVILEGED_CLI=""
else
if command -v podman > /dev/null; then
CONTAINER_CLI=$(command -v podman)
HEALTCHECK_PATH="{{.State.Healthcheck.Status}}"
# Only use sudo for podman
Expand All @@ -13,6 +9,10 @@ else
else
PRIVILEGED_CLI=""
fi
else
CONTAINER_CLI=$(command -v docker)
HEALTCHECK_PATH="{{.State.Health.Status}}"
PRIVILEGED_CLI=""
fi

mysql() {
Expand Down Expand Up @@ -306,7 +306,7 @@ db2_11_5() {

db2_12_1() {
$PRIVILEGED_CLI $CONTAINER_CLI rm -f db2 || true
$PRIVILEGED_CLI $CONTAINER_CLI run --name db2 --privileged -e DB2INSTANCE=orm_test -e DB2INST1_PASSWORD=orm_test -e DBNAME=orm_test -e LICENSE=accept -e AUTOCONFIG=false -e ARCHIVE_LOGS=false -e TO_CREATE_SAMPLEDB=false -e REPODB=false -p 50000:50000 -d ${DB_IMAGE_DB2_11_5:-icr.io/db2_community/db2:12.1.2.0}
$PRIVILEGED_CLI $CONTAINER_CLI run --name db2 --privileged --platform=linux/amd64 -e DB2INSTANCE=orm_test -e DB2INST1_PASSWORD=orm_test -e DBNAME=orm_test -e LICENSE=accept -e AUTOCONFIG=false -e ARCHIVE_LOGS=false -e TO_CREATE_SAMPLEDB=false -e REPODB=false -e IS_OSXFS=true -e BLU=false -e ENABLE_ORACLE_COMPATIBILITY=false -e UPDATEAVAIL=NO -e PERSISTENT_HOME=false -e HADR_ENABLED=false -p 50000:50000 -d ${DB_IMAGE_DB2_12_1:-icr.io/db2_community/db2:12.1.2.0}
# Give the container some time to start
OUTPUT=
while [[ $OUTPUT != *"INSTANCE"* ]]; do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,21 @@ The Hibernate ORM Vector module contains support for mathematical vector types a
This is useful for AI/ML topics like vector similarity search and Retrieval-Augmented Generation (RAG).
The module comes with support for a special `vector` data type that essentially represents an array of bytes, floats, or doubles.

So far, both the PostgreSQL extension `pgvector` and the Oracle database 23ai+ `AI Vector Search` feature are supported, but in theory,
the vector specific functions could be implemented to work with every database that supports arrays.
Currently, the following databases are supported:

For further details, refer to the https://github.com/pgvector/pgvector#querying[pgvector documentation] or the https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[AI Vector Search documentation].
* PostgreSQL 13+ through the https://github.com/pgvector/pgvector#querying[`pgvector` extension]
* https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[Oracle database 23ai+]
* https://mariadb.com/docs/server/reference/sql-structure/vectors/vector-overview[MariaDB 11.7+]
* https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html[MySQL 9.0+]

In theory, the vector-specific functions could be implemented to work with every database that supports arrays.

[WARNING]
====
Per the https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html#function_distance[MySQL documentation],
the various vector distance functions for MySQL only work on MySQL cloud offerings like
https://dev.mysql.com/doc/heatwave/en/mys-hw-about-heatwave.html[HeatWave MySQL on OCI].
====

[[vector-module-setup]]
=== Setup
Expand All @@ -42,22 +53,32 @@ so no further configuration is necessary to make the features available.
[[vector-module-usage]]
==== Usage

Annotate a persistent attribute with `@JdbcTypeCode(SqlTypes.VECTOR)` and specify the vector length with `@Array(length = ...)`.
Annotate a persistent attribute with one of the various vector type codes `@JdbcTypeCode` and specify the vector length with `@Array(length = ...)`.
Possible vector type codes and the compatible Java types are:

* `@JdbcTypeCode(SqlTypes.VECTOR_BINARY)` for `byte[]`
* `@JdbcTypeCode(SqlTypes.VECTOR_INT8)` for `byte[]`
* `@JdbcTypeCode(SqlTypes.VECTOR_FLOAT16)` for `float[]`
* `@JdbcTypeCode(SqlTypes.VECTOR_FLOAT32)` for `float[]`
* `@JdbcTypeCode(SqlTypes.VECTOR_FLOAT64)` for `double[]`
* `@JdbcTypeCode(SqlTypes.VECTOR)` for `float[]`

Hibernate ORM also provides support for sparse vectors through dedicated Java types:

* `@JdbcTypeCode(SqlTypes.SPARSE_VECTOR_INT8)` for `SparseByteVector`
* `@JdbcTypeCode(SqlTypes.SPARSE_VECTOR_FLOAT32)` for `SparseFloatVector`
* `@JdbcTypeCode(SqlTypes.SPARSE_VECTOR_FLOAT64)` for `SparseDoubleVector`

[WARNING]
====
As Oracle AI Vector Search supports different types of elements (to ensure better performance and compatibility with embedding models), you can also use:

- `@JdbcTypeCode(SqlTypes.VECTOR_INT8)` for `byte[]`
- `@JdbcTypeCode(SqlTypes.VECTOR_FLOAT32)` for `float[]`
- `@JdbcTypeCode(SqlTypes.VECTOR_FLOAT64)` for `double[]`.
Vector data type support depends on native support of the underlying database.
====

[[vector-module-usage-example]]
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=usage-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=usage-example]
----
====

Expand All @@ -77,14 +98,21 @@ Expressions of the vector type can be used with various vector functions.
| `euclidean_distance()` | Computes the https://en.wikipedia.org/wiki/Euclidean_distance[euclidean distance] between two vectors. Maps to the `<``-``>` operator for `pgvector` and maps to the
`vector_distance(v1, v2, EUCLIDEAN)` function for `Oracle AI Vector Search`.

| `euclidean_squared_distance()` | Computes the https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance[squared euclidean distance] between two vectors.

| `l2_distance()` | Alias for `euclidean_distance()`

| `l2_squared_distance()` | Alias for `euclidean_squared_distance()`

| `taxicab_distance()` | Computes the https://en.wikipedia.org/wiki/Taxicab_geometry[taxicab distance] between two vectors. Maps to `vector_distance(v1, v2, MANHATTAN)` function for `Oracle AI Vector Search`.

| `l1_distance()` | Alias for `taxicab_distance()`

| `hamming_distance()` | Computes the https://en.wikipedia.org/wiki/Hamming_distance[hamming distance] between two vectors. Maps to `vector_distance(v1, v2, HAMMING)` function for `Oracle AI Vector Search`.

| `jaccard_distance()` | Computes the https://en.wikipedia.org/wiki/Jaccard_index[jaccard distance] between two vectors. Maps to the `<``%``>` operator for `pgvector` and maps to the
`vector_distance(v1, v2, JACCARD)` function for `Oracle AI Vector Search`.

| `inner_product()` | Computes the https://en.wikipedia.org/wiki/Inner_product_space[inner product] between two vectors

| `negative_inner_product()` | Computes the negative inner product. Maps to the `<``#``>` operator for `pgvector` and maps to the
Expand All @@ -93,6 +121,14 @@ Expressions of the vector type can be used with various vector functions.
| `vector_dims()` | Determines the dimensions of a vector

| `vector_norm()` | Computes the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of a vector

| `l2_norm()` | Alias for `vector_norm()`

| `l2_normalize()` | Normalizes each component of a vector by dividing it with the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of the vector.

| `binary_quantize()` | Reduces a vector of size N to a binary vector with N bits, using 0 for values <= 0 and 1 for values > 0.

| `subvector()` | Creates a subvector from a given vector, a 1-based start index and a count.
|===

In addition to these special vector functions, it is also possible to use vectors with the following builtin `pgvector` operators:
Expand All @@ -113,7 +149,7 @@ which is `1 - inner_product( v1, v2 ) / ( vector_norm( v1 ) * vector_norm( v2 )
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=cosine-distance-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=cosine-distance-example]
----
====

Expand All @@ -128,7 +164,22 @@ The `l2_distance()` function is an alias.
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=euclidean-distance-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=euclidean-distance-example]
----
====

[[vector-module-functions-euclidean-squared-distance]]
===== `euclidean_squared_distance()` and `l2_squared_distance()`

Computes the https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance[squared euclidean distance] between two vectors,
which is `sum( (v1_i - v2_i)^2 )`, just like the regular `euclidean_distance`, but without the `sqrt`.
The `l2_squared_distance()` function is an alias.

[[vector-module-functions-euclidean-squared-distance-example]]
====
[source, java, indent=0]
----
include::{example-dir-vector}/FloatVectorTest.java[tags=euclidean-squared-distance-example]
----
====

Expand All @@ -143,7 +194,37 @@ The `l1_distance()` function is an alias.
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=taxicab-distance-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=taxicab-distance-example]
----
====

[[vector-module-functions-hamming-distance]]
===== `hamming_distance()`

Computes the https://en.wikipedia.org/wiki/Hamming_distance[hamming distance] between two binary vectors,
which is `bit_count(v1 ^ v2)` i.e. the amount of bits where two vectors differ.
Maps to the `<``~``>` operator for `pgvector`.

[[vector-module-functions-taxicab-distance-example]]
====
[source, java, indent=0]
----
include::{example-dir-vector}/BinaryVectorTest.java[tags=hamming-distance-example]
----
====

[[vector-module-functions-jaccard-distance]]
===== `jaccard_distance()`

Computes the https://en.wikipedia.org/wiki/Jaccard_index[jaccard distance] between two binary vectors,
which is `1 - bit_count(v1 & v2) / bit_count(v1 | v2)`.
Maps to the `<``%``>` operator for `pgvector`.

[[vector-module-functions-taxicab-distance-example]]
====
[source, java, indent=0]
----
include::{example-dir-vector}/BinaryVectorTest.java[tags=jaccard-distance-example]
----
====

Expand All @@ -158,7 +239,7 @@ and the `inner_product()` function as well, but multiplies the result time `-1`.
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=inner-product-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=inner-product-example]
----
====

Expand All @@ -171,24 +252,63 @@ Determines the dimensions of a vector.
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=vector-dims-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=vector-dims-example]
----
====

[[vector-module-functions-vector-norm]]
===== `vector_norm()`
===== `vector_norm()` and `l2_norm()`

Computes the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of a vector,
which is `sqrt( sum( v_i^2 ) )`.
The `l2_norm()` function is an alias.

[[vector-module-functions-vector-norm-example]]
====
[source, java, indent=0]
----
include::{example-dir-vector}/PGVectorTest.java[tags=vector-norm-example]
include::{example-dir-vector}/FloatVectorTest.java[tags=vector-norm-example]
----
====

[[vector-module-functions-l2-normalize]]
===== `l2_normalize()`

Normalizes each component of a vector by dividing it with the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of the vector.

[[vector-module-functions-l2-normalize-example]]
====
[source, java, indent=0]
----
include::{example-dir-vector}/FloatVectorTest.java[tags=l2-normalize-example]
----
====

[[vector-module-functions-binary-quantize]]
===== `binary_quantize()`

Reduces a vector of size N to a binary vector with N bits, using 0 for values <= 0 and 1 for values > 0.

[[vector-module-functions-binary-quantize-example]]
====
[source, java, indent=0]
----
include::{example-dir-vector}/FloatVectorTest.java[tags=binary-quantize-example]
----
====

[[vector-module-functions-subvector]]
===== `binary_quantize()`

Creates a subvector from a given vector, a 1-based start index and a count.

[[vector-module-functions-subvector-example]]
====
[source, java, indent=0]
----
include::{example-dir-vector}/FloatVectorTest.java[tags=subvector-example]
----
====



Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import java.sql.ResultSet;
import java.sql.SQLException;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.hibernate.dialect.Dialect;
import org.hibernate.engine.jdbc.Size;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.descriptor.ValueBinder;
Expand All @@ -35,6 +37,7 @@ public class GaussDBCastingInetJdbcType implements JdbcType {
@Override
public void appendWriteExpression(
String writeExpression,
@Nullable Size size,
SqlAppender appender,
Dialect dialect) {
appender.append( "cast(" );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import java.sql.ResultSet;
import java.sql.SQLException;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.hibernate.dialect.Dialect;
import org.hibernate.engine.jdbc.Size;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.metamodel.mapping.JdbcMappingContainer;
import org.hibernate.sql.ast.SqlAstTranslator;
Expand Down Expand Up @@ -77,6 +79,7 @@ public JdbcMappingContainer getExpressionType() {
@Override
public void appendWriteExpression(
String writeExpression,
@Nullable Size size,
SqlAppender appender,
Dialect dialect) {
appender.append( '(' );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
*/
package org.hibernate.community.dialect;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.hibernate.dialect.Dialect;
import org.hibernate.engine.jdbc.Size;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.type.descriptor.jdbc.JdbcType;
import org.hibernate.type.descriptor.jdbc.JsonArrayJdbcType;
Expand All @@ -27,6 +29,7 @@ public GaussDBCastingJsonArrayJdbcType(JdbcType elementJdbcType, boolean jsonb)
@Override
public void appendWriteExpression(
String writeExpression,
@Nullable Size size,
SqlAppender appender,
Dialect dialect) {
appender.append( "cast(" );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
*/
package org.hibernate.community.dialect;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.hibernate.dialect.Dialect;
import org.hibernate.engine.jdbc.Size;
import org.hibernate.metamodel.mapping.EmbeddableMappingType;
import org.hibernate.metamodel.spi.RuntimeModelCreationContext;
import org.hibernate.sql.ast.spi.SqlAppender;
Expand Down Expand Up @@ -46,6 +48,7 @@ public AggregateJdbcType resolveAggregateJdbcType(
@Override
public void appendWriteExpression(
String writeExpression,
@Nullable Size size,
SqlAppender appender,
Dialect dialect) {
appender.append( "cast(" );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import java.sql.PreparedStatement;
import java.sql.SQLException;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.hibernate.boot.model.naming.Identifier;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.type.AbstractPostgreSQLStructJdbcType;
import org.hibernate.engine.jdbc.Size;
import org.hibernate.metamodel.mapping.EmbeddableMappingType;
import org.hibernate.metamodel.spi.RuntimeModelCreationContext;
import org.hibernate.sql.ast.spi.SqlAppender;
Expand Down Expand Up @@ -59,6 +61,7 @@ public AggregateJdbcType resolveAggregateJdbcType(
@Override
public void appendWriteExpression(
String writeExpression,
@Nullable Size size,
SqlAppender appender,
Dialect dialect) {
appender.append( "cast(" );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ public class OracleTypes {
public static final int VECTOR_INT8 = -106;
public static final int VECTOR_FLOAT32 = -107;
public static final int VECTOR_FLOAT64 = -108;
public static final int VECTOR_BINARY = -109;
}
Loading
Loading