Skip to content

Commit e32fb25

Browse files
committed
FindAndRerank
1 parent 1daf15c commit e32fb25

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1554
-1155
lines changed

astra-db-java/src/main/java/com/datastax/astra/client/collections/Collection.java

Lines changed: 123 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@
2020
* #L%
2121
*/
2222

23-
import com.datastax.astra.client.collections.definition.CollectionDefaultIdTypes;
24-
import com.datastax.astra.client.collections.definition.CollectionDefinition;
25-
import com.datastax.astra.client.collections.definition.CollectionDescriptor;
26-
import com.datastax.astra.client.collections.definition.documents.Document;
2723
import com.datastax.astra.client.collections.commands.ReturnDocument;
2824
import com.datastax.astra.client.collections.commands.Update;
29-
import com.datastax.astra.client.collections.exceptions.TooManyDocumentsToCountException;
25+
import com.datastax.astra.client.collections.commands.cursor.CollectionFindAndRerankCursor;
26+
import com.datastax.astra.client.collections.commands.cursor.CollectionFindCursor;
3027
import com.datastax.astra.client.collections.commands.options.CollectionDeleteManyOptions;
3128
import com.datastax.astra.client.collections.commands.options.CollectionDeleteOneOptions;
29+
import com.datastax.astra.client.collections.commands.options.CollectionFindAndRerankOptions;
3230
import com.datastax.astra.client.collections.commands.options.CollectionFindOneAndDeleteOptions;
3331
import com.datastax.astra.client.collections.commands.options.CollectionFindOneAndReplaceOptions;
3432
import com.datastax.astra.client.collections.commands.options.CollectionFindOneAndUpdateOptions;
@@ -38,43 +36,56 @@
3836
import com.datastax.astra.client.collections.commands.options.CollectionInsertOneOptions;
3937
import com.datastax.astra.client.collections.commands.options.CollectionReplaceOneOptions;
4038
import com.datastax.astra.client.collections.commands.options.CollectionUpdateManyOptions;
39+
import com.datastax.astra.client.collections.commands.options.CollectionUpdateOneOptions;
4140
import com.datastax.astra.client.collections.commands.options.CountDocumentsOptions;
4241
import com.datastax.astra.client.collections.commands.options.EstimatedCountDocumentsOptions;
43-
import com.datastax.astra.client.collections.commands.options.CollectionUpdateOneOptions;
4442
import com.datastax.astra.client.collections.commands.results.CollectionDeleteResult;
4543
import com.datastax.astra.client.collections.commands.results.CollectionInsertManyResult;
4644
import com.datastax.astra.client.collections.commands.results.CollectionInsertOneResult;
4745
import com.datastax.astra.client.collections.commands.results.CollectionUpdateResult;
4846
import com.datastax.astra.client.collections.commands.results.FindOneAndReplaceResult;
49-
import com.datastax.astra.client.core.options.BaseOptions;
47+
import com.datastax.astra.client.collections.definition.CollectionDefaultIdTypes;
48+
import com.datastax.astra.client.collections.definition.CollectionDefinition;
49+
import com.datastax.astra.client.collections.definition.CollectionDescriptor;
50+
import com.datastax.astra.client.collections.definition.documents.Document;
51+
import com.datastax.astra.client.collections.definition.documents.types.ObjectId;
52+
import com.datastax.astra.client.collections.definition.documents.types.UUIDv6;
53+
import com.datastax.astra.client.collections.definition.documents.types.UUIDv7;
54+
import com.datastax.astra.client.collections.exceptions.TooManyDocumentsToCountException;
55+
import com.datastax.astra.client.core.DataAPIKeywords;
5056
import com.datastax.astra.client.core.commands.Command;
51-
import com.datastax.astra.client.collections.commands.cursor.CollectionFindCursor;
57+
import com.datastax.astra.client.core.options.BaseOptions;
5258
import com.datastax.astra.client.core.paging.Page;
5359
import com.datastax.astra.client.core.query.Filter;
5460
import com.datastax.astra.client.core.query.Filters;
55-
import com.datastax.astra.client.core.DataAPIKeywords;
56-
import com.datastax.astra.client.collections.definition.documents.types.ObjectId;
57-
import com.datastax.astra.client.collections.definition.documents.types.UUIDv6;
58-
import com.datastax.astra.client.collections.definition.documents.types.UUIDv7;
5961
import com.datastax.astra.client.core.query.Projection;
62+
import com.datastax.astra.client.core.reranking.RerankResult;
6063
import com.datastax.astra.client.core.vector.DataAPIVector;
6164
import com.datastax.astra.client.databases.Database;
6265
import com.datastax.astra.client.exceptions.DataAPIException;
6366
import com.datastax.astra.client.exceptions.UnexpectedDataAPIResponseException;
6467
import com.datastax.astra.client.tables.commands.options.TableDistinctOptions;
68+
import com.datastax.astra.client.tables.definition.rows.Row;
6569
import com.datastax.astra.internal.api.DataAPIResponse;
6670
import com.datastax.astra.internal.api.DataAPIStatus;
6771
import com.datastax.astra.internal.command.AbstractCommandRunner;
6872
import com.datastax.astra.internal.serdes.DataAPISerializer;
6973
import com.datastax.astra.internal.serdes.collections.DocumentSerializer;
74+
import com.datastax.astra.internal.serdes.tables.RowMapper;
7075
import com.datastax.astra.internal.utils.Assert;
7176
import com.datastax.astra.internal.utils.EscapeUtils;
7277
import lombok.Getter;
7378
import lombok.extern.slf4j.Slf4j;
7479

7580
import java.time.Duration;
7681
import java.time.Instant;
77-
import java.util.*;
82+
import java.util.ArrayList;
83+
import java.util.Arrays;
84+
import java.util.List;
85+
import java.util.Map;
86+
import java.util.Optional;
87+
import java.util.Set;
88+
import java.util.UUID;
7889
import java.util.concurrent.Callable;
7990
import java.util.concurrent.CompletableFuture;
8091
import java.util.concurrent.ExecutionException;
@@ -91,6 +102,7 @@
91102
import static com.datastax.astra.client.core.options.DataAPIClientOptions.MAX_COUNT;
92103
import static com.datastax.astra.client.exceptions.DataAPIException.ERROR_CODE_INTERRUPTED;
93104
import static com.datastax.astra.client.exceptions.DataAPIException.ERROR_CODE_TIMEOUT;
105+
import static com.datastax.astra.internal.serdes.tables.RowMapper.mapFromRow;
94106
import static com.datastax.astra.internal.utils.AnsiUtils.cyan;
95107
import static com.datastax.astra.internal.utils.AnsiUtils.green;
96108
import static com.datastax.astra.internal.utils.AnsiUtils.magenta;
@@ -1051,6 +1063,96 @@ public CollectionFindCursor<T, T> findAll() {
10511063
return find(null, new CollectionFindOptions());
10521064
}
10531065

1066+
public Page<T> findPage(Filter filter, CollectionFindOptions options) {
1067+
return findPage(filter, options, getDocumentClass());
1068+
}
1069+
1070+
1071+
// -----------------------------
1072+
// --- Find and Rerank ----
1073+
// -----------------------------
1074+
1075+
/**
1076+
* Finds all documents in the collection.
1077+
*
1078+
* @param filter
1079+
* the query filter
1080+
* @param options
1081+
* options of find one
1082+
* @return
1083+
* the find iterable interface
1084+
*/
1085+
public CollectionFindAndRerankCursor<T,T> findAndRerank(Filter filter, CollectionFindAndRerankOptions options) {
1086+
return findAndRerank(filter, options, getDocumentClass());
1087+
}
1088+
1089+
public <R> CollectionFindAndRerankCursor<T, R> findAndRerank(Filter filter, CollectionFindAndRerankOptions options, Class<R> newRowType) {
1090+
return new CollectionFindAndRerankCursor<>(this, filter, options, newRowType);
1091+
}
1092+
1093+
public <R> Page<RerankResult<R>> findAndRerankPage(Filter filter, CollectionFindAndRerankOptions options, Class<R> newRowType) {
1094+
Command findAndRerankCommand = Command
1095+
.create("findAndRerank")
1096+
.withFilter(filter);
1097+
if (options != null) {
1098+
findAndRerankCommand
1099+
.withSort(options.getSortArray())
1100+
.withProjection(options.getProjectionArray())
1101+
.withOptions(new Document()
1102+
.appendIfNotNull("rerankOn", options.rerankOn())
1103+
.appendIfNotNull("limit", options.limit())
1104+
.appendIfNotNull("hybridProjection", options.hybridProjection().getValue())
1105+
.appendIfNotNull("hybridLimits", options.hybridLimits())
1106+
.appendIfNotNull(INPUT_INCLUDE_SORT_VECTOR, options.includeSortVector())
1107+
.appendIfNotNull(INPUT_INCLUDE_SIMILARITY, options.includeSimilarity())
1108+
)
1109+
;
1110+
}
1111+
1112+
// Responses MOCK for now
1113+
DataAPIResponse apiResponse = runCommand(findAndRerankCommand, options);
1114+
1115+
// load sortVector if available
1116+
DataAPIVector sortVector = null;
1117+
if (options != null && options.includeSortVector() != null && apiResponse.getStatus() != null) {
1118+
sortVector = apiResponse.getStatus().getSortVector();
1119+
}
1120+
1121+
List<RerankResult<R>> results = new ArrayList<>();
1122+
List<Document> documents = apiResponse.getData().getDocuments();
1123+
List<Document> documentResponses = apiResponse.getStatus().getDocumentResponses();
1124+
if (documents == null || documentResponses == null) {
1125+
throw new UnexpectedDataAPIResponseException(findAndRerankCommand,
1126+
apiResponse, "Documents or Documents reponses are not retuned");
1127+
}
1128+
if (documents.size() != documentResponses.size()) {
1129+
throw new UnexpectedDataAPIResponseException(findAndRerankCommand,
1130+
apiResponse, "Documents or Documents responses do not match");
1131+
}
1132+
1133+
for(int i = 0; i < documents.size(); i++) {
1134+
1135+
// Getting document and projecting as expected
1136+
Document document = documents.get(i);
1137+
1138+
// MAP WITH DOCUMENT FUNCTION
1139+
DocumentSerializer serializer = new DocumentSerializer();
1140+
R results1 = serializer.convertValue(document, newRowType);
1141+
1142+
// MAP WITH ROW FUNCTION
1143+
Row row = RowMapper.mapAsRow(document);
1144+
R result = RowMapper.mapFromRow(row, getSerializer(), newRowType);
1145+
1146+
// Getting associated document response
1147+
Document documentResponse = documentResponses.get(i);
1148+
Map<String, Double> scores = documentResponse.getMap("scores", String.class, Double.class);
1149+
1150+
results.add(new RerankResult<>(results1, scores));
1151+
}
1152+
// PageState is always NULL
1153+
return new Page<>(null, results, sortVector);
1154+
}
1155+
10541156
/**
10551157
* Executes a paginated 'find' query on the collection using the specified filter and find options.
10561158
* <p>
@@ -1073,7 +1175,7 @@ public CollectionFindCursor<T, T> findAll() {
10731175
* @param options The {@link CollectionFindOptions} providing additional query parameters, such as sorting and pagination.
10741176
* @return A {@link Page} object containing the documents that match the query, along with pagination information.
10751177
*/
1076-
public Page<T> findPage(Filter filter, CollectionFindOptions options) {
1178+
public <R> Page<R> findPage(Filter filter, CollectionFindOptions options, Class<R> newRowType) {
10771179
Command findCommand = Command
10781180
.create("find")
10791181
.withFilter(filter);
@@ -1094,12 +1196,15 @@ public Page<T> findPage(Filter filter, CollectionFindOptions options) {
10941196
if (options != null && options.includeSortVector() != null && apiResponse.getStatus() != null) {
10951197
sortVector = apiResponse.getStatus().getSortVector();
10961198
}
1097-
10981199
return new Page<>(
10991200
apiResponse.getData().getNextPageState(),
1100-
apiResponse.getData().getDocuments()
1101-
.stream()
1102-
.map(d -> d.map(getDocumentClass()))
1201+
apiResponse.getData().getDocuments().stream()
1202+
.map(d -> {
1203+
Row row = RowMapper.mapAsRow(d);
1204+
return mapFromRow(row, getSerializer(), newRowType);
1205+
})
1206+
// .map(d -> d.map(newRowType))
1207+
//.map(d -> RowMapper.mapFromRow(d, getSerializer(), newRowType))
11031208
.collect(Collectors.toList()), sortVector);
11041209
}
11051210

@@ -1129,23 +1234,6 @@ public CompletableFuture<Page<T>> findPageASync(Filter filter, CollectionFindOpt
11291234
return CompletableFuture.supplyAsync(() -> findPage(filter, options));
11301235
}
11311236

1132-
// -----------------------------
1133-
// --- Find and Rerank ----
1134-
// -----------------------------
1135-
1136-
/**
1137-
* Finds all documents in the collection.
1138-
*
1139-
* @param filter
1140-
* the query filter
1141-
* @param options
1142-
* options of find one
1143-
* @return
1144-
* the find iterable interface
1145-
*/
1146-
public CollectionFindCursor<T, T> findAndRerank(Filter filter, CollectionFindOptions options) {
1147-
return new CollectionFindCursor<>(this, filter, options, getDocumentClass());
1148-
}
11491237

11501238
// -------------------------
11511239
// --- distinct ----

0 commit comments

Comments
 (0)