Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ public enum FeatureFlag {
"es.index_dimensions_tsid_optimization_feature_flag_enabled=true",
Version.fromString("9.2.0"),
null
),
ELASTIC_RERANKER_CHUNKING("es.elastic_reranker_chunking_long_documents=true", Version.fromString("9.2.0"), null);
);

public final String systemProperty;
public final Version from;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@ public List<String> getChunkedInputs() {
return chunkedInputs;
}

public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(ActionListener<InferenceServiceResults> listener) {
public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(
ActionListener<InferenceServiceResults> listener,
boolean returnDocuments
) {
return ActionListener.wrap(results -> {
if (results instanceof RankedDocsResults rankedDocsResults) {
listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults));
listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults, returnDocuments));

} else {
listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass()));
Expand All @@ -65,7 +68,7 @@ public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener
}, listener::onFailure);
}

private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) {
private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults, boolean returnDocuments) {
List<RankedDocsResults.RankedDoc> topRankedDocs = new ArrayList<>();
Set<Integer> docIndicesSeen = new HashSet<>();

Expand All @@ -80,7 +83,7 @@ private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults ranke
RankedDocsResults.RankedDoc updatedRankedDoc = new RankedDocsResults.RankedDoc(
docIndex,
rankedDoc.relevanceScore(),
inputs.get(docIndex)
returnDocuments ? inputs.get(docIndex) : null
);
topRankedDocs.add(updatedRankedDoc);
docIndicesSeen.add(docIndex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID;

public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings {
Expand Down Expand Up @@ -102,30 +101,26 @@ public static ElasticRerankerServiceSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();
var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException);

LongDocumentStrategy longDocumentStrategy = null;
Integer maxChunksPerDoc = null;
if (ELASTIC_RERANKER_CHUNKING.isEnabled()) {
longDocumentStrategy = extractOptionalEnum(
map,
LONG_DOCUMENT_STRATEGY,
ModelConfigurations.SERVICE_SETTINGS,
LongDocumentStrategy::fromString,
EnumSet.allOf(LongDocumentStrategy.class),
validationException
LongDocumentStrategy longDocumentStrategy = extractOptionalEnum(
map,
LONG_DOCUMENT_STRATEGY,
ModelConfigurations.SERVICE_SETTINGS,
LongDocumentStrategy::fromString,
EnumSet.allOf(LongDocumentStrategy.class),
validationException
);

Integer maxChunksPerDoc = extractOptionalPositiveInteger(
map,
MAX_CHUNKS_PER_DOC,
ModelConfigurations.SERVICE_SETTINGS,
validationException
);

if (maxChunksPerDoc != null && (longDocumentStrategy == null || longDocumentStrategy == LongDocumentStrategy.TRUNCATE)) {
validationException.addValidationError(
"The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_STRATEGY + "] to be set to [chunk]"
);

maxChunksPerDoc = extractOptionalPositiveInteger(
map,
MAX_CHUNKS_PER_DOC,
ModelConfigurations.SERVICE_SETTINGS,
validationException
);

if (maxChunksPerDoc != null && (longDocumentStrategy == null || longDocumentStrategy == LongDocumentStrategy.TRUNCATE)) {
validationException.addValidationError(
"The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_STRATEGY + "] to be set to [chunk]"
);
}
}

if (validationException.validationErrors().isEmpty() == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.FeatureFlag;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
Expand Down Expand Up @@ -116,8 +115,6 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);

public static final FeatureFlag ELASTIC_RERANKER_CHUNKING = new FeatureFlag("elastic_reranker_chunking_long_documents");

/**
* Fix for https://github.com/elastic/elasticsearch/issues/124675
* In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use
Expand Down Expand Up @@ -698,26 +695,26 @@ public void inferRerank(
}
});

if (model instanceof ElasticRerankerModel elasticRerankerModel && ELASTIC_RERANKER_CHUNKING.isEnabled()) {
var returnDocs = Boolean.TRUE;
if (returnDocuments != null) {
returnDocs = returnDocuments;
} else if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
}

if (model instanceof ElasticRerankerModel elasticRerankerModel) {
var serviceSettings = elasticRerankerModel.getServiceSettings();
var longDocumentStrategy = serviceSettings.getLongDocumentStrategy();
if (longDocumentStrategy == ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK) {
var rerankChunker = new RerankRequestChunker(query, inputs, serviceSettings.getMaxChunksPerDoc());
inputs = rerankChunker.getChunkedInputs();
resultsListener = rerankChunker.parseChunkedRerankResultsListener(resultsListener);
resultsListener = rerankChunker.parseChunkedRerankResultsListener(resultsListener, returnDocs);
}

}
var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);

var returnDocs = Boolean.TRUE;
if (returnDocuments != null) {
returnDocs = returnDocuments;
} else if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
}

Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null;

ActionListener<InferModelAction.Response> mlResultsListener = resultsListener.delegateFailureAndWrap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ public void testParseChunkedRerankResultsListener_NonRankedDocsResults() {
ActionListener.wrap(
results -> fail("Expected failure but got: " + results.getClass()),
e -> assertTrue(e instanceof IllegalArgumentException && e.getMessage().contains("Expected RankedDocsResults"))
)
),
randomBoolean()
);

listener.onResponse(new InferenceServiceResults() {
Expand All @@ -124,18 +125,24 @@ public void testParseChunkedRerankResultsListener_EmptyInput() {
assertThat(results, instanceOf(RankedDocsResults.class));
var rankedDocResults = (RankedDocsResults) results;
assertEquals(0, rankedDocResults.getRankedDocs().size());
}, e -> fail("Expected successful parsing but got failure: " + e)));
}, e -> fail("Expected successful parsing but got failure: " + e)), randomBoolean());
listener.onResponse(new RankedDocsResults(List.of()));
}

public void testParseChunkedRerankResultsListener_SingleInputWithoutChunking() {
var inputs = List.of(generateTestText(10));
var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
var returnDocuments = randomBoolean();
var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
assertThat(results, instanceOf(RankedDocsResults.class));
var rankedDocResults = (RankedDocsResults) results;
assertEquals(1, rankedDocResults.getRankedDocs().size());
}, e -> fail("Expected successful parsing but got failure: " + e)));
if (returnDocuments) {
assertNotNull(rankedDocResults.getRankedDocs().get(0).text());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than just asserting that the text is non-null in this case, would it be better to assert that it matches the expected text:

assertThat(rankedDocResults.getRankedDocs().get(0).text(), is(inputs.getFirst()));

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is related to the comment below. We can discuss it in the comments there.

} else {
assertNull(rankedDocResults.getRankedDocs().get(0).text());
}
}, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments);

var chunkedInputs = chunker.getChunkedInputs();
assertEquals(1, chunkedInputs.size());
Expand All @@ -147,35 +154,48 @@ public void testParseChunkedRerankResultsListener_SingleInputWithChunking() {
var relevanceScore1 = randomFloatBetween(0, 1, true);
var relevanceScore2 = randomFloatBetween(0, 1, true);
var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
var returnDocuments = randomBoolean();
var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
assertThat(results, instanceOf(RankedDocsResults.class));
var rankedDocResults = (RankedDocsResults) results;
assertEquals(1, rankedDocResults.getRankedDocs().size());
var expectedRankedDocs = List.of(new RankedDocsResults.RankedDoc(0, max(relevanceScore1, relevanceScore2), inputs.get(0)));
var expectedRankedDocs = List.of(
new RankedDocsResults.RankedDoc(0, max(relevanceScore1, relevanceScore2), returnDocuments ? inputs.get(0) : null)
);
assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs());
}, e -> fail("Expected successful parsing but got failure: " + e)));
if (returnDocuments) {
assertNotNull(rankedDocResults.getRankedDocs().get(0).text());
} else {
assertNull(rankedDocResults.getRankedDocs().get(0).text());
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These asserts are redundant, since we already assert that the actual results match the expected results on line 165.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll remove these assertions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When constructing the expected RankedDoc on line 163, we're setting the text to be the expected value (either the first input or null) based on whether returnDocuments is true, so we would be able to tell if the document string was returned when doing the comparison:

new RankedDocsResults.RankedDoc(0, max(relevanceScore1, relevanceScore2), returnDocuments ? inputs.get(0) : null)

}, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments);

var chunkedInputs = chunker.getChunkedInputs();
assertEquals(3, chunkedInputs.size());
var rankedDocsResults = List.of(
new RankedDocsResults.RankedDoc(0, relevanceScore1, chunkedInputs.get(0)),
new RankedDocsResults.RankedDoc(1, relevanceScore2, chunkedInputs.get(1))
);
// TODO: Sort this so that the assumption that the results are in order holds
listener.onResponse(new RankedDocsResults(rankedDocsResults));
}

public void testParseChunkedRerankResultsListener_MultipleInputsWithoutChunking() {
var inputs = List.of(generateTestText(10), generateTestText(10));
var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
var returnDocuments = randomBoolean();
var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
assertThat(results, instanceOf(RankedDocsResults.class));
var rankedDocResults = (RankedDocsResults) results;
assertEquals(2, rankedDocResults.getRankedDocs().size());
var sortedResults = new ArrayList<>(rankedDocResults.getRankedDocs());
sortedResults.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
assertEquals(sortedResults, rankedDocResults.getRankedDocs());
}, e -> fail("Expected successful parsing but got failure: " + e)));
if (returnDocuments) {
rankedDocResults.getRankedDocs().forEach(r -> { assertNotNull(r.text()); });
} else {
rankedDocResults.getRankedDocs().forEach(r -> { assertNull(r.text()); });
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it might be possible to assert on the actual value of the text in these tests instead of just whether or not it's null. If we used fixed values for the relevance scores in the ranked docs instead of random ones, then the order of the results would be deterministic and we could know which doc was expected to have which text.

It would also be good to have the inputs list elements not be identical, since if they're the same, then we have no way of making sure that the text is correct when comparing between the two documents. With the test as it is, if there was some weird bug which caused the text from one result to be copied to another result, we would have no way of spotting that.

Copy link
Member Author

@dan-rubinstein dan-rubinstein Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of randomizing the values of the relevance score is to make sure we always properly sort the values at the end and that we always take the highest chunk score per document. The underlying unit being tested shouldn't care whether the scores are in a specific order so I figured randomizing the values would help us test this.

As for the input strings I think we can make 2 changes to make the tests more robust.

  1. Make the input strings different per document to avoid accidentally generating 2 identical ones when passing in 2 documents.
  2. Modify the assertions to specifically check that the correct document string was returned.

Let me know if this makes sense to you.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As long as we can assert that the results contain the correct string, then using random relevance scores is fine, I think.

Ideally, I think that a unit test should test one thing at a time, since that helps pinpoint the specific area that has a bug in the event of a test failure. For example, a test that the list returned by rankedDocResults.getRankedDocs() is sorted by relevance score regardless of the returnDocuments value would use non-random relevance scores for the inputs so that the expected output is fixed, so we don't have to effectively reimplement the sorting logic in the test in order to validate the output. This is just my personal philosophy when it comes to unit testing though, so not something that needs to be changed in this PR.

}, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments);

var chunkedInputs = chunker.getChunkedInputs();
assertEquals(2, chunkedInputs.size());
Expand All @@ -192,14 +212,20 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithoutChunking(
public void testParseChunkedRerankResultsListener_MultipleInputsWithSomeChunking() {
var inputs = List.of(generateTestText(10), generateTestText(100));
var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
var returnDocuments = randomBoolean();
var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
assertThat(results, instanceOf(RankedDocsResults.class));
var rankedDocResults = (RankedDocsResults) results;
assertEquals(2, rankedDocResults.getRankedDocs().size());
var sortedResults = new ArrayList<>(rankedDocResults.getRankedDocs());
sortedResults.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
assertEquals(sortedResults, rankedDocResults.getRankedDocs());
}, e -> fail("Expected successful parsing but got failure: " + e)));
if (returnDocuments) {
rankedDocResults.getRankedDocs().forEach(r -> { assertNotNull(r.text()); });
} else {
rankedDocResults.getRankedDocs().forEach(r -> { assertNull(r.text()); });
}
}, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments);

var chunkedInputs = chunker.getChunkedInputs();
assertEquals(4, chunkedInputs.size());
Expand All @@ -217,14 +243,20 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithSomeChunking
public void testParseChunkedRerankResultsListener_MultipleInputsWithAllRequiringChunking() {
var inputs = List.of(generateTestText(100), generateTestText(100));
var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
var returnDocuments = randomBoolean();
var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
assertThat(results, instanceOf(RankedDocsResults.class));
var rankedDocResults = (RankedDocsResults) results;
assertEquals(2, rankedDocResults.getRankedDocs().size());
var sortedResults = new ArrayList<>(rankedDocResults.getRankedDocs());
sortedResults.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
assertEquals(sortedResults, rankedDocResults.getRankedDocs());
}, e -> fail("Expected successful parsing but got failure: " + e)));
if (returnDocuments) {
rankedDocResults.getRankedDocs().forEach(r -> { assertNotNull(r.text()); });
} else {
rankedDocResults.getRankedDocs().forEach(r -> { assertNull(r.text()); });
}
}, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments);

var chunkedInputs = chunker.getChunkedInputs();
assertEquals(6, chunkedInputs.size());
Expand Down
Loading