Skip to content

Commit 363b859

Browse files
committed
testing with vector
1 parent d967738 commit 363b859

File tree

9 files changed

+3145
-1
lines changed

9 files changed

+3145
-1
lines changed

astra-sdk-samples/sample-quickstart/src/test/java/com/datastax/astra/sdk/quickstart/db/AstraVectorSearchPreviewTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ private Optional<Product> findProductById(String productId) {
9696

9797
private List<Product> findAllSimilarProducts(Product orginal) {
9898
return astraClient.cqlSession().execute(SimpleStatement
99-
.builder("SELECT * FROM pet_supply_vectors ORDER BY product_vector ANN OF ? LIMIT 2;")
99+
.builder("SELECT * FROM pet_supply_vectors " +
100+
"ORDER BY product_vector ANN OF ? LIMIT 2;")
100101
.addPositionalValue(orginal.vector)
101102
.build())
102103
.all()

astra-sdk-vector/pom.xml

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5+
<modelVersion>4.0.0</modelVersion>
6+
<artifactId>astra-sdk-vector</artifactId>
7+
<name> + astra-sdk-vector</name>
8+
9+
<parent>
10+
<groupId>com.datastax.astra</groupId>
11+
<artifactId>astra-sdk-parent</artifactId>
12+
<version>0.6.8-SNAPSHOT</version>
13+
</parent>
14+
15+
<!-- Module-level build properties. The two *-java.version values are referenced through ${...} placeholders by dependencies declared further down in this POM. -->
<properties>
16+
<!-- Version applied to the com.theokanning.openai-gpt3-java:service dependency. -->
<openai-java.version>0.15.0</openai-java.version>
17+
<!-- Version applied to the io.github.hamawhitegg:langchain-core dependency. -->
<langchain-java.version>0.1.11</langchain-java.version>
18+
<!-- Standard Maven compiler property: bytecode target level for this module. -->
<maven.compiler.target>17</maven.compiler.target>
19+
<!-- Standard Maven property: character encoding used when reading source files. -->
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
20+
</properties>
21+
22+
<dependencies>
23+
<dependency>
24+
<groupId>org.projectlombok</groupId>
25+
<artifactId>lombok</artifactId>
26+
<scope>provided</scope>
27+
</dependency>
28+
29+
<dependency>
30+
<groupId>io.github.hamawhitegg</groupId>
31+
<artifactId>langchain-core</artifactId>
32+
<version>${langchain-java.version}</version>
33+
</dependency>
34+
35+
<dependency>
36+
<groupId>com.theokanning.openai-gpt3-java</groupId>
37+
<artifactId>service</artifactId>
38+
<version>${openai-java.version}</version>
39+
</dependency>
40+
41+
<!-- Core astra-sdk artifact, pinned to this module's own version (${project.version}); the astra-sdk-pulsar artifact is excluded from what it brings in transitively. -->
<dependency>
42+
<groupId>com.datastax.astra</groupId>
43+
<artifactId>astra-sdk</artifactId>
44+
<version>${project.version}</version>
45+
<exclusions>
46+
<exclusion>
47+
<groupId>com.datastax.astra</groupId>
48+
<artifactId>astra-sdk-pulsar</artifactId>
49+
</exclusion>
50+
</exclusions>
51+
</dependency>
52+
53+
<dependency>
54+
<groupId>com.datastax.oss</groupId>
55+
<artifactId>java-driver-query-builder</artifactId>
56+
</dependency>
57+
58+
<dependency>
59+
<groupId>org.junit.jupiter</groupId>
60+
<artifactId>junit-jupiter-engine</artifactId>
61+
<scope>test</scope>
62+
</dependency>
63+
<dependency>
64+
<groupId>org.slf4j</groupId>
65+
<artifactId>slf4j-api</artifactId>
66+
</dependency>
67+
<dependency>
68+
<groupId>ch.qos.logback</groupId>
69+
<artifactId>logback-classic</artifactId>
70+
</dependency>
71+
<dependency>
72+
<groupId>org.checkerframework</groupId>
73+
<artifactId>checker-qual</artifactId>
74+
<version>3.33.0</version>
75+
<scope>test</scope>
76+
</dependency>
77+
</dependencies>
78+
</project>
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package com.dtsx.astra.sdk.vector;
2+
3+
import com.datastax.oss.driver.api.core.CqlSession;
4+
5+
public abstract class AbstractVectorTable {
6+
7+
public static final String SAI_INDEX_CLASSNAME = "org.apache.cassandra.index.sai.StorageAttachedIndex";
8+
9+
/** Session to Cassandra. */
10+
protected final CqlSession cqlSession;
11+
12+
protected final String keyspaceName;
13+
14+
protected final String tableName;
15+
16+
public AbstractVectorTable(CqlSession session, String keyspaceName, String tableName) {
17+
this.cqlSession = session;
18+
this.keyspaceName = keyspaceName;
19+
this.tableName = tableName;
20+
}
21+
22+
protected void delete() {
23+
cqlSession.execute("DROP TABLE IF EXISTS " + tableName);
24+
}
25+
26+
protected void clear() {
27+
cqlSession.execute("TRUNCATE TABLE " + tableName);
28+
}
29+
30+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package com.dtsx.astra.sdk.vector;
2+
3+
import com.datastax.oss.driver.api.core.cql.Row;
4+
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
5+
import com.datastax.oss.driver.api.core.data.CqlVector;
6+
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
7+
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
8+
import jakarta.validation.constraints.NotNull;
9+
import lombok.Data;
10+
11+
import java.util.HashMap;
12+
import java.util.List;
13+
import java.util.Map;
14+
import java.util.stream.Collectors;
15+
16+
import static com.dtsx.astra.sdk.vector.MetadataVectorCassandraTable.ATTRIBUTES_BLOB;
17+
import static com.dtsx.astra.sdk.vector.MetadataVectorCassandraTable.BODY_BLOB;
18+
import static com.dtsx.astra.sdk.vector.MetadataVectorCassandraTable.METADATA_S;
19+
import static com.dtsx.astra.sdk.vector.MetadataVectorCassandraTable.ROW_ID;
20+
import static com.dtsx.astra.sdk.vector.MetadataVectorCassandraTable.VECTOR;
21+
22+
@Data
23+
public class MetadataVectorCassandraRecord {
24+
25+
private String rowId;
26+
27+
private String attributes;
28+
29+
private String body;
30+
31+
private Map<String, String> metadata = new HashMap<>();
32+
33+
private List<Float> vector;
34+
35+
@SuppressWarnings("unchecked")
36+
public static MetadataVectorCassandraRecord fromRow(Row cqlRow) {
37+
if (cqlRow == null) return null;
38+
MetadataVectorCassandraRecord record = new MetadataVectorCassandraRecord();
39+
record.setRowId(cqlRow.getString(ROW_ID));
40+
record.setAttributes(cqlRow.getString(ATTRIBUTES_BLOB));
41+
record.setBody(cqlRow.getString(BODY_BLOB));
42+
record.setMetadata(cqlRow.getMap(METADATA_S, String.class, String.class));
43+
record.setVector(((CqlVector<Float>) cqlRow.getObject(VECTOR)).stream().collect(Collectors.toList()));
44+
return record;
45+
}
46+
47+
public SimpleStatement insertStatement(@NotNull String keyspaceName, @NotNull String tableName) {
48+
if (rowId == null) throw new IllegalStateException("Row Id cannot be null");
49+
if (vector == null) throw new IllegalStateException("Vector cannot be null");
50+
RegularInsert regularInser = QueryBuilder
51+
.insertInto(keyspaceName, tableName)
52+
.value(ROW_ID, QueryBuilder.literal(rowId))
53+
.value(VECTOR, QueryBuilder.literal(CqlVector.newInstance(vector)));
54+
if (attributes != null) {
55+
regularInser = regularInser.value(ATTRIBUTES_BLOB, QueryBuilder.literal(attributes));
56+
}
57+
if (body != null) {
58+
regularInser = regularInser.value(BODY_BLOB, QueryBuilder.literal(body));
59+
}
60+
if (metadata != null && !metadata.isEmpty()) {
61+
regularInser = regularInser.value(METADATA_S, QueryBuilder.literal(metadata));
62+
}
63+
return regularInser.build();
64+
}
65+
66+
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package com.dtsx.astra.sdk.vector;
2+
3+
import com.datastax.oss.driver.api.core.CqlSession;
4+
import com.datastax.oss.driver.api.core.cql.ResultSet;
5+
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
6+
import com.datastax.oss.driver.api.core.data.CqlVector;
7+
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
8+
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
9+
import lombok.Getter;
10+
import lombok.extern.slf4j.Slf4j;
11+
12+
import java.util.List;
13+
import java.util.Map;
14+
import java.util.stream.Collectors;
15+
16+
@Slf4j @Getter
17+
public class MetadataVectorCassandraTable extends AbstractVectorTable {
18+
19+
/**
20+
* Rable Structure
21+
*/
22+
public static final String ROW_ID = "row_id";
23+
public static final String ATTRIBUTES_BLOB = "attributes_blob";
24+
public static final String BODY_BLOB = "body_blob";
25+
public static final String METADATA_S = "metadata_s";
26+
public static final String VECTOR = "vector";
27+
28+
private final int vectorDimension;
29+
30+
public MetadataVectorCassandraTable(CqlSession session, String keyspaceName, String tableName, int vectorDimension) {
31+
super(session, keyspaceName, tableName);
32+
this.vectorDimension = vectorDimension;
33+
}
34+
35+
public void createIfNotExist() {
36+
// Create Table
37+
String cql = "CREATE TABLE IF NOT EXISTS " + tableName + " (" +
38+
ROW_ID + " text, " +
39+
ATTRIBUTES_BLOB + " text, " +
40+
BODY_BLOB + " text, " +
41+
METADATA_S + " map<text, text>, " +
42+
VECTOR + " vector<float, " + vectorDimension + ">, " +
43+
"PRIMARY KEY (" +
44+
ROW_ID + ")" +
45+
")";
46+
cqlSession.execute(cql);
47+
log.info("+ Table '{}' has been created (if needed).", tableName);
48+
// Create Vector Index
49+
cqlSession.execute(SchemaBuilder.createIndex("idx_vector_" + tableName)
50+
.ifNotExists()
51+
.custom(SAI_INDEX_CLASSNAME)
52+
.onTable(tableName)
53+
.andColumn(VECTOR)
54+
.build());
55+
log.info("+ Index '{}' has been created (if needed).", "idx_vector_" + tableName);
56+
// Create Metadata Index
57+
cqlSession.execute(SchemaBuilder.createIndex("eidx_metadata_s_" + tableName)
58+
.ifNotExists()
59+
.custom(SAI_INDEX_CLASSNAME)
60+
.onTable(tableName)
61+
.andColumnEntries(METADATA_S)
62+
.build());
63+
log.info("+ Index '{}' has been created (if needed).", "eidx_metadata_s_" + tableName);
64+
}
65+
66+
public void put(String rowId, String bodyBlob, String attributesBlob, Map<String, String> metadata, List<Float> vector) {
67+
cqlSession.execute(QueryBuilder.insertInto(keyspaceName, tableName)
68+
.value(ROW_ID, QueryBuilder.literal(rowId))
69+
.value(BODY_BLOB, QueryBuilder.literal(bodyBlob))
70+
.value(ATTRIBUTES_BLOB, QueryBuilder.literal(attributesBlob))
71+
.value(METADATA_S, QueryBuilder.literal(metadata))
72+
.value(VECTOR, QueryBuilder.literal(CqlVector.newInstance(vector)))
73+
.build());
74+
}
75+
76+
public void put(MetadataVectorCassandraRecord row) {
77+
cqlSession.execute(row.insertStatement(keyspaceName, tableName));
78+
}
79+
80+
81+
82+
public List<MetadataVectorCassandraRecord> ann_search(List<Float> vector, int recordCount, Map<String, String > metatadata) {
83+
ResultSet rs = cqlSession.execute( SimpleStatement.builder(
84+
"SELECT * FROM " + tableName
85+
+ " ORDER BY vector ANN OF ? LIMIT ?")
86+
.addPositionalValue(CqlVector.newInstance(vector))
87+
.addPositionalValue(recordCount)
88+
.build());
89+
return rs.all().stream()
90+
.map(MetadataVectorCassandraRecord::fromRow)
91+
.collect(Collectors.toList());
92+
}
93+
94+
}

0 commit comments

Comments
 (0)