Skip to content

Commit c288af8

Browse files
authored
Optimize feature retrieval for Cassandra online storage (#24)
Signed-off-by: Khor Shu Heng <[email protected]> Co-authored-by: Khor Shu Heng <[email protected]>
1 parent f60155a commit c288af8

File tree

2 files changed

+48
-14
lines changed

2 files changed

+48
-14
lines changed

storage/connectors/cassandra/src/main/java/feast/storage/connectors/cassandra/retriever/CassandraOnlineRetriever.java

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
*/
1717
package feast.storage.connectors.cassandra.retriever;
1818

19+
import com.datastax.oss.driver.api.core.AsyncPagingIterable;
1920
import com.datastax.oss.driver.api.core.CqlSession;
20-
import com.datastax.oss.driver.api.core.cql.BoundStatement;
21+
import com.datastax.oss.driver.api.core.cql.AsyncResultSet;
22+
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
2123
import com.datastax.oss.driver.api.core.cql.Row;
2224
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
2325
import com.datastax.oss.driver.api.querybuilder.select.Select;
@@ -30,9 +32,11 @@
3032
import java.io.IOException;
3133
import java.nio.ByteBuffer;
3234
import java.util.*;
35+
import java.util.concurrent.CompletableFuture;
36+
import java.util.concurrent.CompletionStage;
37+
import java.util.concurrent.ExecutionException;
3338
import java.util.function.Function;
3439
import java.util.stream.Collectors;
35-
import java.util.stream.StreamSupport;
3640
import org.apache.avro.AvroRuntimeException;
3741
import org.apache.avro.generic.GenericDatumReader;
3842
import org.apache.avro.generic.GenericRecord;
@@ -157,12 +161,40 @@ public Map<ByteBuffer, Row> getFeaturesFromSSTable(
157161
for (String columnFamily : columnFamilies) {
158162
query = query.writeTime(columnFamily).as(columnFamily + EVENT_TIMESTAMP_SUFFIX);
159163
}
160-
query = query.whereColumn(ENTITY_KEY).in(QueryBuilder.bindMarker());
164+
query = query.whereColumn(ENTITY_KEY).isEqualTo(QueryBuilder.bindMarker());
161165

162-
BoundStatement statement = session.prepare(query.build()).bind(rowKeys);
166+
PreparedStatement preparedStatement = session.prepare(query.build());
163167

164-
return StreamSupport.stream(session.execute(statement).spliterator(), false)
165-
.collect(Collectors.toMap((Row row) -> row.getByteBuffer(ENTITY_KEY), Function.identity()));
168+
List<CompletableFuture<AsyncResultSet>> completableAsyncResultSets =
169+
rowKeys.stream()
170+
.map(preparedStatement::bind)
171+
.map(session::executeAsync)
172+
.map(CompletionStage::toCompletableFuture)
173+
.collect(Collectors.toList());
174+
175+
CompletableFuture<Void> allResultComputed =
176+
CompletableFuture.allOf(completableAsyncResultSets.toArray(new CompletableFuture[0]));
177+
178+
Map<ByteBuffer, Row> resultMap;
179+
try {
180+
resultMap =
181+
allResultComputed
182+
.thenApply(
183+
v ->
184+
completableAsyncResultSets.stream()
185+
.map(CompletableFuture::join)
186+
.filter(result -> result.remaining() != 0)
187+
.map(AsyncPagingIterable::one)
188+
.filter(Objects::nonNull)
189+
.collect(
190+
Collectors.toMap(
191+
(Row row) -> row.getByteBuffer(ENTITY_KEY), Function.identity())))
192+
.get();
193+
} catch (InterruptedException | ExecutionException e) {
194+
throw new RuntimeException(e.getMessage());
195+
}
196+
197+
return resultMap;
166198
}
167199

168200
/**

storage/connectors/cassandra/src/main/java/feast/storage/connectors/cassandra/retriever/CassandraSchemaRegistry.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import com.datastax.oss.driver.api.core.CqlSession;
2020
import com.datastax.oss.driver.api.core.cql.BoundStatement;
21+
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
2122
import com.datastax.oss.driver.api.core.cql.Row;
2223
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
2324
import com.datastax.oss.driver.api.querybuilder.select.Select;
@@ -34,6 +35,7 @@
3435

3536
public class CassandraSchemaRegistry {
3637
private final CqlSession session;
38+
private final PreparedStatement preparedStatement;
3739
private final LoadingCache<SchemaReference, GenericDatumReader<GenericRecord>> cache;
3840

3941
private static String SCHEMA_REF_TABLE = "feast_schema_reference";
@@ -67,6 +69,13 @@ public int hashCode() {
6769

6870
public CassandraSchemaRegistry(CqlSession session) {
6971
this.session = session;
72+
String tableName = String.format("\"%s\"", SCHEMA_REF_TABLE);
73+
Select query =
74+
QueryBuilder.selectFrom(tableName)
75+
.column(SCHEMA_COLUMN)
76+
.whereColumn(SCHEMA_REF_COLUMN)
77+
.isEqualTo(QueryBuilder.bindMarker());
78+
this.preparedStatement = session.prepare(query.build());
7079

7180
CacheLoader<SchemaReference, GenericDatumReader<GenericRecord>> schemaCacheLoader =
7281
CacheLoader.from(this::loadReader);
@@ -85,14 +94,7 @@ public GenericDatumReader<GenericRecord> getReader(SchemaReference reference) {
8594
}
8695

8796
private GenericDatumReader<GenericRecord> loadReader(SchemaReference reference) {
88-
String tableName = String.format("\"%s\"", SCHEMA_REF_TABLE);
89-
Select query =
90-
QueryBuilder.selectFrom(tableName)
91-
.column(SCHEMA_COLUMN)
92-
.whereColumn(SCHEMA_REF_COLUMN)
93-
.isEqualTo(QueryBuilder.bindMarker());
94-
95-
BoundStatement statement = session.prepare(query.build()).bind(reference.getSchemaHash());
97+
BoundStatement statement = preparedStatement.bind(reference.getSchemaHash());
9698

9799
Row row = session.execute(statement).one();
98100

0 commit comments

Comments
 (0)