Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package com.microsoft.semantickernel.data.jdbc.oracle;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.microsoft.semantickernel.data.filter.FilterClause;
import com.microsoft.semantickernel.data.jdbc.JDBCVectorStore;
import com.microsoft.semantickernel.data.jdbc.JDBCVectorStoreOptions;
import com.microsoft.semantickernel.data.jdbc.JDBCVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.data.jdbc.oracle.OracleVectorStoreQueryProvider.StringTypeMapping;
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchFilter;
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResults;
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordCollection;
Expand All @@ -12,17 +14,21 @@
import com.microsoft.semantickernel.data.vectorstorage.annotations.VectorStoreRecordVector;
import com.microsoft.semantickernel.data.vectorstorage.definition.DistanceFunction;
import com.microsoft.semantickernel.data.vectorstorage.definition.IndexKind;
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDataField;
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordKeyField;
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordVectorField;
import com.microsoft.semantickernel.data.vectorstorage.options.GetRecordOptions;
import com.microsoft.semantickernel.data.vectorstorage.options.VectorSearchOptions;
import com.microsoft.semantickernel.exceptions.SKException;

import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -33,6 +39,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand Down Expand Up @@ -86,6 +93,156 @@ void testUseCollectionVec() {
collection.deleteCollectionAsync().block();
}

// Test index types
@Nested
class HNSWIndexTests {
    /**
     * Creates a collection whose single vector field requests {@link IndexKind#HNSW} and
     * checks, via {@code USER_INDEXES}, that exactly one index with the expected name exists.
     *
     * <p>The expected index name is the vector field's effective storage name, upper-cased,
     * followed by {@code _VECTOR_INDEX}.
     *
     * @throws Exception if obtaining a connection or running the verification/cleanup SQL fails
     */
    @Test
    void testHNSWIndexIsCreatedSuccessfully() throws Exception {
        VectorStoreRecordKeyField keyField = VectorStoreRecordKeyField.builder()
            .withName("id")
            .withStorageName("id")
            .withFieldType(String.class)
            .build();

        VectorStoreRecordDataField dummyField = VectorStoreRecordDataField.builder()
            .withName("dummy")
            .withStorageName("dummy")
            .withFieldType(String.class)
            .isFilterable(false)
            .build();

        VectorStoreRecordVectorField hnswVector = VectorStoreRecordVectorField.builder()
            .withName("hnsw")
            .withStorageName("hnsw")
            .withFieldType(List.class)
            .withDimensions(8)
            .withDistanceFunction(DistanceFunction.COSINE_SIMILARITY)
            .withIndexKind(IndexKind.HNSW)
            .build();

        VectorStoreRecordDefinition definition = VectorStoreRecordDefinition.fromFields(
            Arrays.asList(keyField, dummyField, hnswVector));

        OracleVectorStoreQueryProvider queryProvider = OracleVectorStoreQueryProvider.builder()
            .withDataSource(DATA_SOURCE)
            .build();

        JDBCVectorStore vectorStore = JDBCVectorStore.builder()
            .withDataSource(DATA_SOURCE)
            .withOptions(JDBCVectorStoreOptions.builder()
                .withQueryProvider(queryProvider)
                .build())
            .build();

        String collectionName = "skhotels_hnsw";
        VectorStoreRecordCollection<String, Object> collection =
            vectorStore.getCollection(collectionName,
                JDBCVectorStoreRecordCollectionOptions.<Object>builder()
                    .withRecordClass(Object.class)
                    .withRecordDefinition(definition)
                    .build());

        // Create the collection; index creation is expected to happen as part of this call.
        collection.createCollectionAsync().block();

        String expectedIndexName =
            hnswVector.getEffectiveStorageName().toUpperCase() + "_VECTOR_INDEX";

        // Check that the index exists in the current schema.
        try (Connection conn = DATA_SOURCE.getConnection();
            PreparedStatement stmt = conn.prepareStatement(
                "SELECT COUNT(*) FROM USER_INDEXES WHERE INDEX_NAME=?")) {
            stmt.setString(1, expectedIndexName);
            // Close the ResultSet explicitly instead of relying on the statement to do it.
            try (ResultSet rs = stmt.executeQuery()) {
                assertTrue(rs.next(), "COUNT(*) query should return a row");
                int count = rs.getInt(1);

                assertEquals(1, count, "hnsw vector index should have been created");
            }
        } finally {
            // Clean up: drop the backing table directly.
            // NOTE(review): assumes the default "SKCOLLECTION_" table prefix -- confirm.
            try (Connection conn = DATA_SOURCE.getConnection();
                Statement stmt = conn.createStatement()) {
                stmt.executeUpdate("DROP TABLE " + "SKCOLLECTION_" + collectionName);
            }
        }
    }
}

@Nested
class UndefinedIndexTests {
    /**
     * Creates a collection whose single vector field uses {@link IndexKind#UNDEFINED} and
     * checks, via {@code USER_INDEXES}, that no index with the vector-index name exists.
     *
     * <p>The name looked up is the vector field's effective storage name, upper-cased,
     * followed by {@code _VECTOR_INDEX} -- the same pattern the HNSW test expects to find.
     * Looking up any other name would make the "count is zero" assertion pass vacuously.
     *
     * @throws Exception if obtaining a connection or running the verification/cleanup SQL fails
     */
    @Test
    void testNoIndexIsCreatedForUndefined() throws Exception {
        // Key field.
        VectorStoreRecordKeyField keyField = VectorStoreRecordKeyField.builder()
            .withName("id")
            .withStorageName("id")
            .withFieldType(String.class)
            .build();

        // Vector field with IndexKind set to UNDEFINED.
        VectorStoreRecordVectorField undefinedVector = VectorStoreRecordVectorField.builder()
            .withName("undef")
            .withStorageName("undef")
            .withFieldType(List.class)
            .withDimensions(8)
            .withDistanceFunction(DistanceFunction.COSINE_SIMILARITY)
            .withIndexKind(IndexKind.UNDEFINED)
            .build();

        VectorStoreRecordDataField dummyField = VectorStoreRecordDataField.builder()
            .withName("dummy")
            .withStorageName("dummy")
            .withFieldType(String.class)
            .isFilterable(false)
            .build();

        VectorStoreRecordDefinition definition = VectorStoreRecordDefinition.fromFields(
            Arrays.asList(keyField, dummyField, undefinedVector));

        OracleVectorStoreQueryProvider queryProvider = OracleVectorStoreQueryProvider.builder()
            .withDataSource(DATA_SOURCE)
            .build();

        JDBCVectorStore vectorStore = JDBCVectorStore.builder()
            .withDataSource(DATA_SOURCE)
            .withOptions(JDBCVectorStoreOptions.builder()
                .withQueryProvider(queryProvider)
                .build())
            .build();

        String collectionName = "skhotels_undefined";
        VectorStoreRecordCollection<String, Object> collection =
            vectorStore.getCollection(collectionName,
                JDBCVectorStoreRecordCollectionOptions.<Object>builder()
                    .withRecordClass(Object.class)
                    .withRecordDefinition(definition)
                    .build());

        // Create the collection.
        collection.createCollectionAsync().block();

        // Check that no vector index exists for the field.
        String expectedIndexName =
            undefinedVector.getEffectiveStorageName().toUpperCase() + "_VECTOR_INDEX";
        try (Connection conn = DATA_SOURCE.getConnection();
            PreparedStatement stmt = conn.prepareStatement(
                "SELECT COUNT(*) FROM USER_INDEXES WHERE INDEX_NAME = ?")) {
            stmt.setString(1, expectedIndexName);
            // Close the ResultSet explicitly instead of relying on the statement to do it.
            try (ResultSet rs = stmt.executeQuery()) {
                assertTrue(rs.next(), "COUNT(*) query should return a row");
                int count = rs.getInt(1);

                assertEquals(0, count,
                    "Vector index should not be created for IndexKind.UNDEFINED");
            }
        } finally {
            // Clean up: drop the backing table directly.
            // NOTE(review): assumes the default "SKCOLLECTION_" table prefix -- confirm.
            try (Connection conn = DATA_SOURCE.getConnection();
                Statement stmt = conn.createStatement()) {
                stmt.executeUpdate("DROP TABLE " + "SKCOLLECTION_" + collectionName);
            }
        }
    }
}

// Test corner-case
@Test
void testUseCLOB() {
Expand Down Expand Up @@ -134,6 +291,32 @@ void testClobLongText() {
collection.deleteCollectionAsync().block();
}

/**
 * Upserts two records into a collection created with CLOB string mapping -- one whose text
 * is 6000 characters long and one short -- then runs a vector search and checks that both
 * records come back in the expected order ("big" first, then "small").
 *
 * <p>The collection is deleted in a {@code finally} block so that a failed assertion does not
 * leave it behind for later tests.
 */
@Test
void testSearchClob() {
    VectorStoreRecordCollection<String, DummyRecordForCLOB> collection =
        createCollection(
            "clob_long_text",
            DummyRecordForCLOB.class,
            OracleVectorStoreQueryProvider.StringTypeMapping.USE_CLOB);

    try {
        String longText = String.join("", java.util.Collections.nCopies(6000, "a"));
        DummyRecordForCLOB d1 = new DummyRecordForCLOB("small", longText, vec(2));
        DummyRecordForCLOB d2 = new DummyRecordForCLOB("big", "short", vec(1));

        collection.upsertBatchAsync(Arrays.asList(d1, d2), null).block();

        VectorSearchResults<DummyRecordForCLOB> results =
            collection.searchAsync(vec(1),
                VectorSearchOptions.builder()
                    .build()
            ).block();

        // block() may yield null; fail with a clear message rather than an NPE.
        assertNotNull(results, "search should return a result object");
        assertEquals(2, results.getTotalCount());
        // The record stored with vec(1) is expected to rank first for a vec(1) query.
        assertEquals("big", results.getResults().get(0).getRecord().getId());
        assertEquals("small", results.getResults().get(1).getRecord().getId());
    } finally {
        collection.deleteCollectionAsync().block();
    }
}

@Test
void testMultipleFilter() {
VectorStoreRecordCollection<String, DummyRecordForMultipleFilter> collection =
Expand Down Expand Up @@ -246,6 +429,22 @@ void testSkipAndTop() {
collection.deleteCollectionAsync().block();
}

/**
 * Verifies that searching with a vector field name that is not part of the record
 * definition fails with an {@link SKException} whose message names the missing field.
 *
 * <p>The collection is deleted in a {@code finally} block so that a failed assertion does not
 * leave it behind for later tests.
 */
@Test
void testSearchVectorField_throws() {
    VectorStoreRecordCollection<String, DummyRecord> collection =
        createCollection(
            "search_vector_field_throws",
            DummyRecord.class,
            null);

    try {
        SKException ex = assertThrows(SKException.class, () -> collection.searchAsync(
            null, VectorSearchOptions.builder().withVectorFieldName("notexist").build()).block());
        assertTrue(ex.getMessage().contains("Field not found: notexist"));
    } finally {
        collection.deleteCollectionAsync().block();
    }
}

// corner case for OracleVectorStoreRecordMapper
@Test
void testMapRecordToStorageModel_throws() {
Expand All @@ -272,6 +471,95 @@ OracleVectorStoreRecordMapper.<DummyRecord> builder()
assertEquals("Not implemented", ex.getMessage());
}

// corner case for OracleVectorStoreQueryProvider
/**
 * Passing a filter that contains a clause type unknown to the provider must raise an
 * {@link SKException} that names the offending clause class.
 */
@Test
void testUnsupportedFilterClause() {
    OracleVectorStoreQueryProvider queryProvider =
        OracleVectorStoreQueryProvider.builder().withDataSource(DATA_SOURCE).build();

    // A filter made of a single clause of a type the provider has no handling for.
    VectorSearchFilter unsupportedFilter =
        new VectorSearchFilter(Arrays.asList(new DummyFilterClause()));

    SKException thrown = assertThrows(
        SKException.class,
        () -> queryProvider.getFilterParameters(unsupportedFilter));

    String message = thrown.getMessage();
    assertTrue(message.contains("Unsupported filter clause type 'DummyFilterClause'."));
}

// corner case for OracleVectorStoreQueryProvider#Builder
/**
 * Builds a provider with a custom collections table and table prefix, then reads the
 * corresponding private fields reflectively to confirm the builder stored both values.
 *
 * @throws NoSuchFieldException if one of the inspected fields does not exist
 * @throws IllegalAccessException if a field cannot be read reflectively
 */
@Test
void testOracleQueryProviderBuilder() throws NoSuchFieldException, IllegalAccessException {
    OracleVectorStoreQueryProvider built = OracleVectorStoreQueryProvider.builder()
        .withDataSource(DATA_SOURCE)
        .withCollectionsTable("myCollection")
        .withPrefixForCollectionTables("myPrefix_")
        .build();
    Class<?> providerType = built.getClass();

    // "collectionsTable" is looked up on the provider's own class.
    java.lang.reflect.Field tableField = providerType.getDeclaredField("collectionsTable");
    tableField.setAccessible(true);
    Object storedTable = tableField.get(built);
    assertEquals("myCollection", storedTable);

    // "prefixForCollectionTables" is looked up on the direct superclass.
    java.lang.reflect.Field prefixField =
        providerType.getSuperclass().getDeclaredField("prefixForCollectionTables");
    prefixField.setAccessible(true);
    Object storedPrefix = prefixField.get(built);
    assertEquals("myPrefix_", storedPrefix);
}

/**
 * Builds a provider with a default VARCHAR size of 1234, creates a collection through it,
 * and checks in {@code USER_TAB_COLUMNS} that the {@code DESCRIPTION} column was created as
 * {@code VARCHAR2} with a data length of 1234.
 *
 * <p>The collection is always deleted afterwards. A {@link SQLException} during verification
 * is rethrown wrapped in a {@link RuntimeException} with the original as its cause.
 */
@Test
void testOracleQueryProviderBuilder_withDefaultVarcharSize() {
    OracleVectorStoreQueryProvider provider = OracleVectorStoreQueryProvider.builder()
        .withDataSource(DATA_SOURCE)
        .withDefaultVarcharSize(1234)
        .build();

    JDBCVectorStore vectorStore = JDBCVectorStore.builder()
        .withDataSource(DATA_SOURCE)
        .withOptions(JDBCVectorStoreOptions.builder()
            .withQueryProvider(provider)
            .build())
        .build();

    VectorStoreRecordCollection<String, DummyRecordForCLOB> collection =
        vectorStore.getCollection("with_default_varchar_size",
            JDBCVectorStoreRecordCollectionOptions.<DummyRecordForCLOB>builder()
                .withRecordClass(DummyRecordForCLOB.class)
                .build()).createCollectionAsync().block();

    try (Connection c = DATA_SOURCE.getConnection();
        PreparedStatement st = c.prepareStatement(
            "SELECT DATA_TYPE, DATA_LENGTH FROM USER_TAB_COLUMNS " +
                "WHERE TABLE_NAME = 'SKCOLLECTION_WITH_DEFAULT_VARCHAR_SIZE' " +
                "AND COLUMN_NAME = 'DESCRIPTION'");
        ResultSet rs = st.executeQuery()) {
        // Without this check a missing row surfaces as an opaque driver error on getString.
        assertTrue(rs.next(),
            "DESCRIPTION column should exist in SKCOLLECTION_WITH_DEFAULT_VARCHAR_SIZE");
        assertEquals("VARCHAR2", rs.getString("DATA_TYPE"));
        assertEquals(1234, rs.getInt("DATA_LENGTH"));

    } catch (SQLException e) {
        throw new RuntimeException(e);
    } finally {
        collection.deleteCollectionAsync().block();
    }
}

/**
 * Supplies a customised {@link ObjectMapper} (indented output enabled) to the builder and
 * confirms, by reading the provider's private {@code objectMapper} field reflectively, that
 * the very same instance -- with its configuration intact -- is held by the provider.
 *
 * @throws NoSuchFieldException if the {@code objectMapper} field does not exist
 * @throws IllegalAccessException if the field cannot be read reflectively
 */
@Test
void testOracleQueryProviderBuilder_withObjectMapper()
    throws NoSuchFieldException, IllegalAccessException {
    ObjectMapper customMapper = new ObjectMapper().enable(SerializationFeature.INDENT_OUTPUT);

    OracleVectorStoreQueryProvider built = OracleVectorStoreQueryProvider.builder()
        .withDataSource(DATA_SOURCE)
        .withObjectMapper(customMapper)
        .build();

    java.lang.reflect.Field mapperField = built.getClass().getDeclaredField("objectMapper");
    mapperField.setAccessible(true);
    ObjectMapper storedMapper = (ObjectMapper) mapperField.get(built);

    // Identity, not just equality: the builder must not copy or replace the mapper.
    assertSame(customMapper, storedMapper);
    assertTrue(storedMapper.isEnabled(SerializationFeature.INDENT_OUTPUT));
}


private <T> VectorStoreRecordCollection<String, T> createCollection(
String collectionName,
Class<T> recordClass,
Expand Down Expand Up @@ -475,4 +763,6 @@ public float[] getVec() {
return vec;
}
}

/**
 * Empty {@link FilterClause} implementation. {@code testUnsupportedFilterClause} passes it
 * to the query provider to check that an unrecognised clause type is rejected.
 */
private static class DummyFilterClause implements FilterClause {}
}
Loading
Loading