
Commit e5d2280

Commit message: OK

1 parent 363b859 · commit e5d2280

19 files changed: +1950 −450 lines

astra-sdk-vector/pom.xml

Lines changed: 4 additions & 4 deletions
@@ -14,7 +14,7 @@
 
     <properties>
         <openai-java.version>0.15.0</openai-java.version>
-        <langchain-java.version>0.1.11</langchain-java.version>
+        <langchain4j.version>0.22.0</langchain4j.version>
         <maven.compiler.target>17</maven.compiler.target>
         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
     </properties>
@@ -27,9 +27,9 @@
        </dependency>
 
        <dependency>
-            <groupId>io.github.hamawhitegg</groupId>
-            <artifactId>langchain-core</artifactId>
-            <version>${langchain-java.version}</version>
+            <groupId>dev.langchain4j</groupId>
+            <artifactId>langchain4j</artifactId>
+            <version>${langchain4j.version}</version>
        </dependency>
 
        <dependency>
AbstractCassandraTable.java (new file, package com.dtsx.astra.sdk.cassio)

Lines changed: 115 additions & 0 deletions

package com.dtsx.astra.sdk.cassio;

import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;

import java.util.concurrent.CompletableFuture;

/**
 * Abstract class for table management at the Cassandra level.
 */
public abstract class AbstractCassandraTable<RECORD> {

    /**
     * Class needed to create a SAI index.
     */
    public static final String SAI_INDEX_CLASSNAME = "org.apache.cassandra.index.sai.StorageAttachedIndex";

    /**
     * Table structure.
     */
    public static final String PARTITION_ID      = "partition_id";
    public static final String ROW_ID            = "row_id";
    public static final String ATTRIBUTES_BLOB   = "attributes_blob";
    public static final String BODY_BLOB         = "body_blob";
    public static final String METADATA_S        = "metadata_s";
    public static final String VECTOR            = "vector";
    public static final String COLUMN_SIMILARITY = "similarity";

    /**
     * Default number of items retrieved.
     */
    public static final int DEFAULT_RECORD_COUNT = 4;

    /** Session to Cassandra. */
    protected final CqlSession cqlSession;

    /** Destination keyspace. */
    protected final String keyspaceName;

    /** Destination table. */
    protected final String tableName;

    /**
     * Default constructor.
     *
     * @param session
     *      cassandra session
     * @param keyspaceName
     *      keyspace
     * @param tableName
     *      table name
     */
    public AbstractCassandraTable(CqlSession session, String keyspaceName, String tableName) {
        this.cqlSession   = session;
        this.keyspaceName = keyspaceName;
        this.tableName    = tableName;
    }

    /**
     * Create the table if it does not exist.
     */
    public abstract void createSchema();

    /**
     * Upsert a row of the table.
     *
     * @param row
     *      current row
     */
    public abstract void put(RECORD row);

    /**
     * Should be able to map a Cassandra row to a record.
     *
     * @param row
     *      current cassandra row
     * @return
     *      current record
     */
    public abstract RECORD mapRow(Row row);

    /**
     * Insert a row asynchronously.
     *
     * @param inputRow
     *      current row
     * @return
     *      future completed once the row has been inserted
     */
    public CompletableFuture<Void> putAsync(final RECORD inputRow) {
        return CompletableFuture.runAsync(() -> put(inputRow));
    }

    /**
     * Delete the table.
     */
    public void delete() {
        cqlSession.execute(SchemaBuilder
                .dropTable(keyspaceName, tableName)
                .ifExists()
                .build());
    }

    /**
     * Empty the table.
     */
    public void clear() {
        cqlSession.execute(QueryBuilder
                .truncate(keyspaceName, tableName)
                .build());
    }

}
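For orientation, the sketch below is not part of the commit; it shows the minimal work a concrete table has to do against this contract: define a schema, implement put, and map rows back to records. SimpleTextTable and its Record type are hypothetical names; the ClusteredCassandraTable added in the same commit is the real example.

package com.dtsx.astra.sdk.cassio;

import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;

// Hypothetical minimal subclass: a table with a text partition key,
// a text row id and a text body, reusing the inherited column names.
public class SimpleTextTable extends AbstractCassandraTable<SimpleTextTable.Record> {

    public SimpleTextTable(CqlSession session, String keyspaceName, String tableName) {
        super(session, keyspaceName, tableName);
        createSchema();
    }

    @Override
    public void createSchema() {
        // Create the table on first use.
        cqlSession.execute(SchemaBuilder.createTable(keyspaceName, tableName)
                .ifNotExists()
                .withPartitionKey(PARTITION_ID, DataTypes.TEXT)
                .withClusteringColumn(ROW_ID, DataTypes.TEXT)
                .withColumn(BODY_BLOB, DataTypes.TEXT)
                .build());
    }

    @Override
    public void put(Record row) {
        // Simple literal-based insert; a production table would prepare the statement.
        cqlSession.execute(QueryBuilder.insertInto(keyspaceName, tableName)
                .value(PARTITION_ID, QueryBuilder.literal(row.partitionId()))
                .value(ROW_ID, QueryBuilder.literal(row.rowId()))
                .value(BODY_BLOB, QueryBuilder.literal(row.body()))
                .build());
    }

    @Override
    public Record mapRow(Row row) {
        return new Record(row.getString(PARTITION_ID), row.getString(ROW_ID), row.getString(BODY_BLOB));
    }

    public record Record(String partitionId, String rowId, String body) { }
}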
ClusteredCassandraTable.java (new file, package com.dtsx.astra.sdk.cassio)

Lines changed: 202 additions & 0 deletions

package com.dtsx.astra.sdk.cassio;

import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.BatchStatement;
import com.datastax.oss.driver.api.core.cql.BatchStatementBuilder;
import com.datastax.oss.driver.api.core.cql.BatchType;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;

import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;

/**
 * Table representing persistence for LangChain operations.
 */
@Slf4j
public class ClusteredCassandraTable extends AbstractCassandraTable<ClusteredCassandraTable.Record> {

    /**
     * Prepared statements.
     */
    private final PreparedStatement findPartitionStatement;
    private final PreparedStatement deletePartitionStatement;
    private final PreparedStatement deleteRowStatement;
    private final PreparedStatement insertRowStatement;
    private final PreparedStatement findRowStatement;

    /**
     * Constructor with the mandatory parameters.
     *
     * @param session
     *      cassandra session
     * @param keyspaceName
     *      keyspace name
     * @param tableName
     *      table name
     */
    public ClusteredCassandraTable(@NonNull CqlSession session, @NonNull String keyspaceName, @NonNull String tableName) {
        super(session, keyspaceName, tableName);
        createSchema();
        findPartitionStatement = session.prepare(QueryBuilder.selectFrom(tableName).all()
                .whereColumn(PARTITION_ID).isEqualTo(QueryBuilder.bindMarker())
                .build());
        deletePartitionStatement = session.prepare(QueryBuilder.deleteFrom(tableName)
                .whereColumn(PARTITION_ID).isEqualTo(QueryBuilder.bindMarker())
                .build());
        findRowStatement = session.prepare(QueryBuilder.selectFrom(tableName).all()
                .whereColumn(PARTITION_ID).isEqualTo(QueryBuilder.bindMarker())
                .whereColumn(ROW_ID).isEqualTo(QueryBuilder.bindMarker())
                .build());
        deleteRowStatement = session.prepare(QueryBuilder.deleteFrom(tableName)
                .whereColumn(PARTITION_ID).isEqualTo(QueryBuilder.bindMarker())
                .whereColumn(ROW_ID).isEqualTo(QueryBuilder.bindMarker())
                .build());
        insertRowStatement = session.prepare(QueryBuilder.insertInto(tableName)
                .value(PARTITION_ID, QueryBuilder.bindMarker())
                .value(ROW_ID, QueryBuilder.bindMarker())
                .value(BODY_BLOB, QueryBuilder.bindMarker())
                .build());
    }

    /** {@inheritDoc} */
    @Override
    public void createSchema() {
        cqlSession.execute(SchemaBuilder.createTable(tableName)
                .ifNotExists()
                .withPartitionKey(PARTITION_ID, DataTypes.TEXT)
                .withClusteringColumn(ROW_ID, DataTypes.TIMEUUID)
                .withColumn(BODY_BLOB, DataTypes.TEXT)
                .withClusteringOrder(ROW_ID, ClusteringOrder.DESC)
                .build());
        log.info("+ Table '{}' has been created (if needed).", tableName);
    }

    /** {@inheritDoc} */
    @Override
    public void put(@NonNull ClusteredCassandraTable.Record row) {
        cqlSession.execute(insertRowStatement.bind(row.getPartitionId(), row.getRowId(), row.getBody()));
    }

    /** {@inheritDoc} */
    @Override
    public Record mapRow(@NonNull Row row) {
        return new Record(
                row.getString(PARTITION_ID),
                row.getUuid(ROW_ID),
                row.getString(BODY_BLOB));
    }

    /**
     * Find a partition.
     *
     * @param partitionId
     *      partition id
     * @return
     *      list of rows
     */
    public List<Record> findPartition(@NonNull String partitionId) {
        return cqlSession.execute(findPartitionStatement.bind(partitionId))
                .all().stream()
                .map(this::mapRow)
                .collect(Collectors.toList());
    }

    /**
     * Update the history in one go.
     *
     * @param rows
     *      current rows
     */
    public void upsertPartition(List<Record> rows) {
        if (rows != null && !rows.isEmpty()) {
            BatchStatementBuilder batch = BatchStatement.builder(BatchType.LOGGED);
            String currentPartitionId = null;
            for (Record row : rows) {
                if (currentPartitionId != null && !currentPartitionId.equals(row.getPartitionId())) {
                    log.warn("Not all rows are part of the same partition");
                }
                currentPartitionId = row.getPartitionId();
                batch.addStatement(insertRowStatement.bind(row.getPartitionId(), row.getRowId(), row.getBody()));
            }
            cqlSession.execute(batch.build());
        }
    }

    /**
     * Find a row by its id.
     *
     * @param partition
     *      partition id
     * @param rowId
     *      row id
     * @return
     *      record if it exists
     */
    public Optional<Record> findById(String partition, UUID rowId) {
        return Optional.ofNullable(cqlSession
                .execute(findRowStatement.bind(partition, rowId))
                .one()).map(this::mapRow);
    }

    /**
     * Delete a partition.
     *
     * @param partitionId
     *      id of the partition to delete
     */
    public void deletePartition(@NonNull String partitionId) {
        cqlSession.execute(deletePartitionStatement.bind(partitionId));
    }

    /**
     * Delete one row.
     *
     * @param partitionId
     *      partition id
     * @param rowId
     *      row id
     */
    public void delete(@NonNull String partitionId, @NonNull UUID rowId) {
        cqlSession.execute(deleteRowStatement.bind(partitionId, rowId));
    }

    /**
     * Insert a row.
     *
     * @param partitionId
     *      partition id
     * @param rowId
     *      row id
     * @param bodyBlob
     *      body
     */
    public void insert(@NonNull String partitionId, @NonNull UUID rowId, @NonNull String bodyBlob) {
        cqlSession.execute(QueryBuilder.insertInto(keyspaceName, tableName)
                .value(PARTITION_ID, QueryBuilder.literal(partitionId))
                .value(ROW_ID, QueryBuilder.literal(rowId))
                .value(BODY_BLOB, QueryBuilder.literal(bodyBlob))
                .build());
    }

    /**
     * Record stored in the clustered table.
     */
    @Data @AllArgsConstructor @NoArgsConstructor
    public static class Record {

        private String partitionId;

        private UUID rowId;

        private String body;
    }

}
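As a usage note, here is a minimal sketch, not part of the commit, of driving the table above from an already-connected CqlSession. The keyspace, table and partition names are illustrative, and Uuids.timeBased() is used because row_id is declared as a TIMEUUID clustering column.

import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.uuid.Uuids;
import com.dtsx.astra.sdk.cassio.ClusteredCassandraTable;
import java.util.List;
import java.util.concurrent.CompletableFuture;

public class ChatHistoryDemo {

    public static void main(String[] args) {
        // Assumes a Cassandra endpoint reachable with the default driver configuration
        // and a keyspace named demo_ks already created.
        try (CqlSession session = CqlSession.builder().withKeyspace("demo_ks").build()) {

            // The constructor creates the table if needed (it calls createSchema()).
            ClusteredCassandraTable history =
                    new ClusteredCassandraTable(session, "demo_ks", "chat_history");

            // Synchronous and asynchronous upserts into the same partition.
            history.put(new ClusteredCassandraTable.Record("chat-42", Uuids.timeBased(), "Hello"));
            CompletableFuture<Void> pending =
                    history.putAsync(new ClusteredCassandraTable.Record("chat-42", Uuids.timeBased(), "World"));
            pending.join();

            // Read the whole partition back, most recent first (clustering order DESC).
            List<ClusteredCassandraTable.Record> messages = history.findPartition("chat-42");
            messages.forEach(r -> System.out.println(r.getRowId() + " -> " + r.getBody()));

            // Remove the conversation.
            history.deletePartition("chat-42");
        }
    }
}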
