Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
Expand Down Expand Up @@ -185,7 +186,7 @@ public ScoreDoc[] rankQueryPhaseResults(
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
return new RankFeaturePhaseRankShardContext(field) {
@Override
public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
Expand Down Expand Up @@ -330,7 +331,7 @@ public ScoreDoc[] rankQueryPhaseResults(
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT)
return new RankFeaturePhaseRankShardContext(field) {
@Override
Expand All @@ -339,7 +340,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
}
};
else {
return super.buildRankFeaturePhaseShardContext();
return super.buildRankFeaturePhaseShardContext(searchContext);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
Expand Down Expand Up @@ -389,7 +390,7 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
return new RerankingRankFeaturePhaseRankShardContext(field);
}

Expand Down Expand Up @@ -532,7 +533,7 @@ public ScoreDoc[] rankQueryPhaseResults(
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT)
return new RankFeaturePhaseRankShardContext(field) {
@Override
Expand All @@ -541,7 +542,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
}
};
else {
return super.buildRankFeaturePhaseShardContext();
return super.buildRankFeaturePhaseShardContext(searchContext);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
Expand Down Expand Up @@ -106,7 +107,7 @@ public int rankWindowSize() {
* Generates a context used to execute the rank feature phase on the shard. This is responsible for retrieving any needed
* feature data, and passing them back to the coordinator through the appropriate {@link RankShardResult}.
*/
public abstract RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext();
public abstract RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext);

/**
* Generates a context used to perform global ranking during the RankFeature phase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public static void processFetch(SearchContext searchContext) {
}

RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = searchContext.request().source().rankBuilder() != null
? searchContext.request().source().rankBuilder().buildRankFeaturePhaseShardContext()
? searchContext.request().source().rankBuilder().buildRankFeaturePhaseShardContext(searchContext)
: null;
if (rankFeaturePhaseRankShardContext != null) {
// TODO: here we populate the profile part of the fetchResult as well
Expand All @@ -94,7 +94,7 @@ public static void processFetch(SearchContext searchContext) {

private static RankFeaturePhaseRankShardContext shardContext(SearchContext searchContext) {
return searchContext.request().source() != null && searchContext.request().source().rankBuilder() != null
? searchContext.request().source().rankBuilder().buildRankFeaturePhaseShardContext()
? searchContext.request().source().rankBuilder().buildRankFeaturePhaseShardContext(searchContext)
: null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.rank.RankBuilder;
Expand Down Expand Up @@ -900,7 +901,7 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
return rankFeaturePhaseRankShardContext;
}

Expand Down Expand Up @@ -976,7 +977,7 @@ private void buildRankFeatureResult(
try {
hits = SearchHits.unpooled(searchHits, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), maxScore);
// construct the appropriate RankFeatureDoc objects based on the rank builder
RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardRankBuilder.buildRankFeaturePhaseShardContext();
RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardRankBuilder.buildRankFeaturePhaseShardContext(null);
RankFeatureShardResult rankShardResult = (RankFeatureShardResult) rankFeaturePhaseRankShardContext.buildRankFeatureShardResult(
hits,
shardTarget.getShardId().id()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) {
@Override
public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
Expand Down Expand Up @@ -748,7 +748,7 @@ public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) {
@Override
public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
Expand Down Expand Up @@ -875,7 +875,7 @@ public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) {
@Override
public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
Expand Down Expand Up @@ -1008,7 +1008,7 @@ public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) {
@Override
public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
Expand Down Expand Up @@ -1136,7 +1136,7 @@ public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) {
@Override
public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
return new RankFeaturePhaseRankShardContext(field) {
@Override
public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
Expand Down Expand Up @@ -100,7 +101,7 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ public class InferenceFeatures implements FeatureSpecification {
private static final NodeFeature TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE = new NodeFeature(
"test_reranking_service.parse_text_as_score"
);
private static final NodeFeature RERANKING_CHECK_FIELD_EXISTS = new NodeFeature(
"text_similarity_reranker.check_field_exists"
);

@Override
public Set<NodeFeature> getTestFeatures() {
Expand All @@ -50,6 +53,7 @@ public Set<NodeFeature> getTestFeatures() {
SEMANTIC_TEXT_HIGHLIGHTER_DEFAULT,
SEMANTIC_KNN_FILTER_FIX,
TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE,
RERANKING_CHECK_FIELD_EXISTS,
SemanticTextFieldMapper.SEMANTIC_TEXT_BIT_VECTOR_SUPPORT,
SemanticTextFieldMapper.SEMANTIC_TEXT_HANDLE_EMPTY_INPUT
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
Expand Down Expand Up @@ -139,7 +140,7 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
return new RerankingRankFeaturePhaseRankShardContext(field);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
Expand Down Expand Up @@ -165,7 +167,20 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {

// check field in mapping
Mapper mapper;
try {
mapper = searchContext.indexShard().mapperService().mappingLookup().getMapper(field);
} catch (NullPointerException e) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't use this control flow like this. Why is this throwing an NPE at all? Is it because the mapper service is null or the search context isn't the correct context?

If so, please verify the search context has the required information before making the check. Otherwise, its possible that this gets false positives (e.g. something marked as missing when it actually isn't).

mapper = null;
}

if (mapper == null) {
throw new IllegalArgumentException("field [" + field + "] does not exist in mapping");
}

return new RerankingRankFeaturePhaseRankShardContext(field);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,18 @@ public void testRerankInputSizeAndInferenceResultsMismatch() {
assertThat(ex.getDetailedMessage(), containsString("Reranker input document count and returned score count mismatch"));
}

public void testRerankInputSizeAndInferenceResultsFieldMissing() {
SearchPhaseExecutionException ex = expectThrows(
SearchPhaseExecutionException.class,
// Execute search with text similarity reranking
client.prepareSearch()
.setRankBuilder(new TextSimilarityRankBuilder("missing_field", "my-rerank-model", "my query", 100, 0.0f, false))
.setQuery(QueryBuilders.matchAllQuery())
);
assertThat(ex.status(), equalTo(RestStatus.BAD_REQUEST));
assertThat(ex.getDetailedMessage(), containsString("field [missing_field] does not exist in mapping"));
}

private static Matcher<SearchHit> searchHitWith(int expectedRank, float expectedScore, String expectedText) {
return allOf(
hasRank(expectedRank),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankShardResult;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
Expand Down Expand Up @@ -193,7 +194,7 @@ public void doWriteTo(StreamOutput out) throws IOException {
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
if (this.throwingRankBuilderType == AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT)
return new RankFeaturePhaseRankShardContext(field()) {
@Override
Expand All @@ -202,7 +203,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
}
};
else {
return super.buildRankFeaturePhaseShardContext();
return super.buildRankFeaturePhaseShardContext(searchContext);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,29 @@ setup:
- match: { hits.total.value: 1 }
- length: { hits.hits: 1 }
- match: { hits.hits.0._id: "doc_1" }

---
"Text similarity reranking fails if the rerank field is missing":
- requires:
cluster_features: "text_similarity_reranker.check_field_exists"
reason: "text_similarity_reranker will check if field exists"

- do:
catch: /field \[missing_field\] does not exist in mapping/
search:
index: test-index
body:
track_total_hits: true
fields: [ "text", "topic" ]
retriever:
text_similarity_reranker:
retriever:
standard:
query:
term:
topic: "science"
rank_window_size: 10
inference_id: my-rerank-model
inference_text: "How often does the moon hide the sun?"
field: missing_field
size: 10
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
Expand Down Expand Up @@ -183,7 +184,7 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si
}

@Override
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext(SearchContext searchContext) {
return null;
}

Expand Down