Skip to content

Commit 31f8a6a

Browse files
Removing elastic reranker chunking feature flag and allow return_documents to be false (#136045) (#136222)
* Removing elastic reranker chunking feature flag * Allow return_documents to be set to false * Updating unit tests to verify returned document strings --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent e9f9aaf commit 31f8a6a

File tree

8 files changed

+158
-239
lines changed

8 files changed

+158
-239
lines changed

test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ public enum FeatureFlag {
2525
"es.index_dimensions_tsid_optimization_feature_flag_enabled=true",
2626
Version.fromString("9.2.0"),
2727
null
28-
),
29-
ELASTIC_RERANKER_CHUNKING("es.elastic_reranker_chunking_long_documents=true", Version.fromString("9.2.0"), null);
28+
);
3029

3130
public final String systemProperty;
3231
public final Version from;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,13 @@ public List<String> getChunkedInputs() {
5353
return chunkedInputs;
5454
}
5555

56-
public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(ActionListener<InferenceServiceResults> listener) {
56+
public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(
57+
ActionListener<InferenceServiceResults> listener,
58+
boolean returnDocuments
59+
) {
5760
return ActionListener.wrap(results -> {
5861
if (results instanceof RankedDocsResults rankedDocsResults) {
59-
listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults));
62+
listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults, returnDocuments));
6063

6164
} else {
6265
listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass()));
@@ -65,7 +68,7 @@ public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener
6568
}, listener::onFailure);
6669
}
6770

68-
private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) {
71+
private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults, boolean returnDocuments) {
6972
List<RankedDocsResults.RankedDoc> topRankedDocs = new ArrayList<>();
7073
Set<Integer> docIndicesSeen = new HashSet<>();
7174

@@ -80,7 +83,7 @@ private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults ranke
8083
RankedDocsResults.RankedDoc updatedRankedDoc = new RankedDocsResults.RankedDoc(
8184
docIndex,
8285
rankedDoc.relevanceScore(),
83-
inputs.get(docIndex)
86+
returnDocuments ? inputs.get(docIndex) : null
8487
);
8588
topRankedDocs.add(updatedRankedDoc);
8689
docIndicesSeen.add(docIndex);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
2424
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
25-
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING;
2625
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID;
2726

2827
public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings {
@@ -102,30 +101,26 @@ public static ElasticRerankerServiceSettings fromMap(Map<String, Object> map) {
102101
ValidationException validationException = new ValidationException();
103102
var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException);
104103

105-
LongDocumentStrategy longDocumentStrategy = null;
106-
Integer maxChunksPerDoc = null;
107-
if (ELASTIC_RERANKER_CHUNKING.isEnabled()) {
108-
longDocumentStrategy = extractOptionalEnum(
109-
map,
110-
LONG_DOCUMENT_STRATEGY,
111-
ModelConfigurations.SERVICE_SETTINGS,
112-
LongDocumentStrategy::fromString,
113-
EnumSet.allOf(LongDocumentStrategy.class),
114-
validationException
104+
LongDocumentStrategy longDocumentStrategy = extractOptionalEnum(
105+
map,
106+
LONG_DOCUMENT_STRATEGY,
107+
ModelConfigurations.SERVICE_SETTINGS,
108+
LongDocumentStrategy::fromString,
109+
EnumSet.allOf(LongDocumentStrategy.class),
110+
validationException
111+
);
112+
113+
Integer maxChunksPerDoc = extractOptionalPositiveInteger(
114+
map,
115+
MAX_CHUNKS_PER_DOC,
116+
ModelConfigurations.SERVICE_SETTINGS,
117+
validationException
118+
);
119+
120+
if (maxChunksPerDoc != null && (longDocumentStrategy == null || longDocumentStrategy == LongDocumentStrategy.TRUNCATE)) {
121+
validationException.addValidationError(
122+
"The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_STRATEGY + "] to be set to [chunk]"
115123
);
116-
117-
maxChunksPerDoc = extractOptionalPositiveInteger(
118-
map,
119-
MAX_CHUNKS_PER_DOC,
120-
ModelConfigurations.SERVICE_SETTINGS,
121-
validationException
122-
);
123-
124-
if (maxChunksPerDoc != null && (longDocumentStrategy == null || longDocumentStrategy == LongDocumentStrategy.TRUNCATE)) {
125-
validationException.addValidationError(
126-
"The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_STRATEGY + "] to be set to [chunk]"
127-
);
128-
}
129124
}
130125

131126
if (validationException.validationErrors().isEmpty() == false) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import org.elasticsearch.common.logging.DeprecationCategory;
1717
import org.elasticsearch.common.logging.DeprecationLogger;
1818
import org.elasticsearch.common.settings.Settings;
19-
import org.elasticsearch.common.util.FeatureFlag;
2019
import org.elasticsearch.common.util.LazyInitializable;
2120
import org.elasticsearch.core.Nullable;
2221
import org.elasticsearch.core.Strings;
@@ -116,8 +115,6 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
116115
private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
117116
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);
118117

119-
public static final FeatureFlag ELASTIC_RERANKER_CHUNKING = new FeatureFlag("elastic_reranker_chunking_long_documents");
120-
121118
/**
122119
* Fix for https://github.com/elastic/elasticsearch/issues/124675
123120
* In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use
@@ -698,26 +695,26 @@ public void inferRerank(
698695
}
699696
});
700697

701-
if (model instanceof ElasticRerankerModel elasticRerankerModel && ELASTIC_RERANKER_CHUNKING.isEnabled()) {
698+
var returnDocs = Boolean.TRUE;
699+
if (returnDocuments != null) {
700+
returnDocs = returnDocuments;
701+
} else if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
702+
var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
703+
returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
704+
}
705+
706+
if (model instanceof ElasticRerankerModel elasticRerankerModel) {
702707
var serviceSettings = elasticRerankerModel.getServiceSettings();
703708
var longDocumentStrategy = serviceSettings.getLongDocumentStrategy();
704709
if (longDocumentStrategy == ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK) {
705710
var rerankChunker = new RerankRequestChunker(query, inputs, serviceSettings.getMaxChunksPerDoc());
706711
inputs = rerankChunker.getChunkedInputs();
707-
resultsListener = rerankChunker.parseChunkedRerankResultsListener(resultsListener);
712+
resultsListener = rerankChunker.parseChunkedRerankResultsListener(resultsListener, returnDocs);
708713
}
709714

710715
}
711716
var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);
712717

713-
var returnDocs = Boolean.TRUE;
714-
if (returnDocuments != null) {
715-
returnDocs = returnDocuments;
716-
} else if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
717-
var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
718-
returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
719-
}
720-
721718
Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null;
722719

723720
ActionListener<InferModelAction.Response> mlResultsListener = resultsListener.delegateFailureAndWrap(

0 commit comments

Comments
 (0)