|
16 | 16 | */ |
17 | 17 | package feast.storage.connectors.cassandra.retriever; |
18 | 18 |
|
| 19 | +import com.datastax.oss.driver.api.core.AsyncPagingIterable; |
19 | 20 | import com.datastax.oss.driver.api.core.CqlSession; |
20 | | -import com.datastax.oss.driver.api.core.cql.BoundStatement; |
| 21 | +import com.datastax.oss.driver.api.core.cql.AsyncResultSet; |
| 22 | +import com.datastax.oss.driver.api.core.cql.PreparedStatement; |
21 | 23 | import com.datastax.oss.driver.api.core.cql.Row; |
22 | 24 | import com.datastax.oss.driver.api.querybuilder.QueryBuilder; |
23 | 25 | import com.datastax.oss.driver.api.querybuilder.select.Select; |
|
30 | 32 | import java.io.IOException; |
31 | 33 | import java.nio.ByteBuffer; |
32 | 34 | import java.util.*; |
| 35 | +import java.util.concurrent.CompletableFuture; |
| 36 | +import java.util.concurrent.CompletionStage; |
| 37 | +import java.util.concurrent.ExecutionException; |
33 | 38 | import java.util.function.Function; |
34 | 39 | import java.util.stream.Collectors; |
35 | | -import java.util.stream.StreamSupport; |
36 | 40 | import org.apache.avro.AvroRuntimeException; |
37 | 41 | import org.apache.avro.generic.GenericDatumReader; |
38 | 42 | import org.apache.avro.generic.GenericRecord; |
@@ -157,12 +161,40 @@ public Map<ByteBuffer, Row> getFeaturesFromSSTable( |
157 | 161 | for (String columnFamily : columnFamilies) { |
158 | 162 | query = query.writeTime(columnFamily).as(columnFamily + EVENT_TIMESTAMP_SUFFIX); |
159 | 163 | } |
160 | | - query = query.whereColumn(ENTITY_KEY).in(QueryBuilder.bindMarker()); |
| 164 | + query = query.whereColumn(ENTITY_KEY).isEqualTo(QueryBuilder.bindMarker()); |
161 | 165 |
|
162 | | - BoundStatement statement = session.prepare(query.build()).bind(rowKeys); |
| 166 | + PreparedStatement preparedStatement = session.prepare(query.build()); |
163 | 167 |
|
164 | | - return StreamSupport.stream(session.execute(statement).spliterator(), false) |
165 | | - .collect(Collectors.toMap((Row row) -> row.getByteBuffer(ENTITY_KEY), Function.identity())); |
| 168 | + List<CompletableFuture<AsyncResultSet>> completableAsyncResultSets = |
| 169 | + rowKeys.stream() |
| 170 | + .map(preparedStatement::bind) |
| 171 | + .map(session::executeAsync) |
| 172 | + .map(CompletionStage::toCompletableFuture) |
| 173 | + .collect(Collectors.toList()); |
| 174 | + |
| 175 | + CompletableFuture<Void> allResultComputed = |
| 176 | + CompletableFuture.allOf(completableAsyncResultSets.toArray(new CompletableFuture[0])); |
| 177 | + |
| 178 | + Map<ByteBuffer, Row> resultMap; |
| 179 | + try { |
| 180 | + resultMap = |
| 181 | + allResultComputed |
| 182 | + .thenApply( |
| 183 | + v -> |
| 184 | + completableAsyncResultSets.stream() |
| 185 | + .map(CompletableFuture::join) |
| 186 | + .filter(result -> result.remaining() != 0) |
| 187 | + .map(AsyncPagingIterable::one) |
| 188 | + .filter(Objects::nonNull) |
| 189 | + .collect( |
| 190 | + Collectors.toMap( |
| 191 | + (Row row) -> row.getByteBuffer(ENTITY_KEY), Function.identity()))) |
| 192 | + .get(); |
| 193 | + } catch (InterruptedException | ExecutionException e) { |
| 194 | + throw new RuntimeException(e.getMessage()); |
| 195 | + } |
| 196 | + |
| 197 | + return resultMap; |
166 | 198 | } |
167 | 199 |
|
168 | 200 | /** |
|
0 commit comments