@@ -296,6 +296,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SEARCH_LOAD_PER_INDEX_STATS = def(9_095_0_00);
public static final TransportVersion HEAP_USAGE_IN_CLUSTER_INFO = def(9_096_0_00);
public static final TransportVersion NONE_CHUNKING_STRATEGY = def(9_097_0_00);
public static final TransportVersion RERANK_SNIPPETS = def(9_098_0_00);

/*
* STOP! READ THIS FIRST! No, really,
@@ -172,7 +172,8 @@ public void onFailure(Exception e) {
context.getOriginalIndices(queryResult.getShardIndex()),
queryResult.getContextId(),
queryResult.getShardSearchRequest(),
entry
entry,
rankFeaturePhaseRankCoordinatorContext.snippets()
),
context.getTask(),
listener
@@ -399,7 +399,7 @@ public Field(String name) {
this.name = name;
}

private Field(Field template, QueryBuilder builder) {
Field(Field template, QueryBuilder builder) {
super(template, builder);
name = template.name;
fragmentOffset = template.fragmentOffset;
@@ -40,7 +40,7 @@ public static class Field {
private final String field;
private final FieldOptions fieldOptions;

Field(String field, FieldOptions fieldOptions) {
public Field(String field, FieldOptions fieldOptions) {
assert field != null;
assert fieldOptions != null;
this.field = field;
@@ -12,6 +12,7 @@
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
import org.elasticsearch.search.rank.feature.RerankSnippetInput;

import java.util.Arrays;
import java.util.Comparator;
@@ -30,18 +31,30 @@ public abstract class RankFeaturePhaseRankCoordinatorContext {
protected final int from;
protected final int rankWindowSize;
protected final boolean failuresAllowed;
protected final RerankSnippetInput snippets;

public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean failuresAllowed) {
public RankFeaturePhaseRankCoordinatorContext(
int size,
int from,
int rankWindowSize,
boolean failuresAllowed,
RerankSnippetInput snippets
) {
this.size = size < 0 ? DEFAULT_SIZE : size;
this.from = from < 0 ? DEFAULT_FROM : from;
this.rankWindowSize = rankWindowSize;
this.failuresAllowed = failuresAllowed;
this.snippets = snippets;
}

public boolean failuresAllowed() {
return failuresAllowed;
}

public RerankSnippetInput snippets() {
return snippets;
}

/**
* Computes the updated scores for a list of features (i.e. document-based data). We also pass along an ActionListener
* that should be called with the new scores, and will continue execution to the next phase
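For orientation, a concrete coordinator context built on the new five-argument constructor could look like the sketch below. It is illustrative only: the class name is invented, the computeScores signature is assumed from the truncated javadoc above (and mirrored by the test subclass at the end of this diff), and the identity scoring stands in for a real reranker call.

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
import org.elasticsearch.search.rank.feature.RerankSnippetInput;

// Hypothetical subclass: forwards a snippet configuration through the new constructor.
public class PassThroughRankCoordinatorContext extends RankFeaturePhaseRankCoordinatorContext {

    public PassThroughRankCoordinatorContext(int size, int from, int rankWindowSize) {
        // Ask shards for up to 2 snippets of at most 300 characters each (arbitrary example values).
        super(size, from, rankWindowSize, false, new RerankSnippetInput(2, 300));
    }

    @Override
    protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
        // Keep the original scores; a real implementation would call out to a reranking service here.
        float[] scores = new float[featureDocs.length];
        for (int i = 0; i < featureDocs.length; i++) {
            scores[i] = featureDocs[i].score;
        }
        scoreListener.onResponse(scores);
    }
}

Passing null for the snippet input, as the existing test subclass further down does, keeps the previous whole-field behaviour, since the shard phase only builds a highlight context when a snippet input is present.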
@@ -10,12 +10,14 @@
package org.elasticsearch.search.rank.feature;

import org.apache.lucene.search.Explanation;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

/**
@@ -27,6 +29,8 @@ public class RankFeatureDoc extends RankDoc {

// TODO: update to support more than one field, and not restrict to string data
public String featureData;
public List<String> snippets;
public List<Integer> docIndices;

public RankFeatureDoc(int doc, float score, int shardIndex) {
super(doc, score, shardIndex);
@@ -35,6 +39,10 @@ public RankFeatureDoc(int doc, float score, int shardIndex) {
public RankFeatureDoc(StreamInput in) throws IOException {
super(in);
featureData = in.readOptionalString();
if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
snippets = in.readOptionalStringCollectionAsList();
docIndices = in.readOptionalCollectionAsList(StreamInput::readVInt);
}
}

@Override
@@ -46,20 +54,34 @@ public void featureData(String featureData) {
this.featureData = featureData;
}

public void snippets(List<String> snippets) {
this.snippets = snippets;
}

public void docIndices(List<Integer> docIndices) {
this.docIndices = docIndices;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalString(featureData);
if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
out.writeOptionalStringCollection(snippets);
out.writeOptionalCollection(docIndices, StreamOutput::writeVInt);
}
}

@Override
protected boolean doEquals(RankDoc rd) {
RankFeatureDoc other = (RankFeatureDoc) rd;
return Objects.equals(this.featureData, other.featureData);
return Objects.equals(this.featureData, other.featureData)
&& Objects.equals(this.snippets, other.snippets)
&& Objects.equals(this.docIndices, other.docIndices);
}

@Override
protected int doHashCode() {
return Objects.hashCode(featureData);
return Objects.hash(featureData, snippets, docIndices);
}

@Override
@@ -70,5 +92,7 @@ public String getWriteableName() {
@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("featureData", featureData);
builder.array("snippets", snippets);
builder.array("docIndices", docIndices);
}
}
@@ -17,10 +17,13 @@
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.fetch.subphase.FetchFieldsContext;
import org.elasticsearch.search.fetch.subphase.FieldAndFormat;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
import org.elasticsearch.tasks.TaskCancelledException;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;

@@ -48,10 +51,30 @@ public static void prepareForFetch(SearchContext searchContext, RankFeatureShard

RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardContext(searchContext);
if (rankFeaturePhaseRankShardContext != null) {
assert rankFeaturePhaseRankShardContext.getField() != null : "field must not be null";
String field = rankFeaturePhaseRankShardContext.getField();
assert field != null : "field must not be null";
searchContext.fetchFieldsContext(
new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(rankFeaturePhaseRankShardContext.getField(), null)))
);
try {
RerankSnippetInput snippets = request.snippets();
if (snippets != null) {
// For POC purposes we're just stripping pre/post tags and deferring if/how we'd want to handle them for this use case.
HighlightBuilder highlightBuilder = new HighlightBuilder().field(field).preTags("").postTags("");
// Force sorting by score so that the first snippet always has the highest score
highlightBuilder.order(HighlightBuilder.Order.SCORE);
if (snippets.numFragments() != null) {
highlightBuilder.numOfFragments(snippets.numFragments());
}
if (snippets.maxSize() != null) {
highlightBuilder.fragmentSize(snippets.maxSize());
}
SearchHighlightContext searchHighlightContext = highlightBuilder.build(searchContext.getSearchExecutionContext());
searchContext.highlight(searchHighlightContext);
}
} catch (IOException e) {
throw new RuntimeException("Failed to create highlight context", e);
}
searchContext.storedFieldsContext(StoredFieldsContext.fromList(Collections.singletonList(StoredFieldsContext._NONE_)));
searchContext.addFetchResult();
Arrays.sort(request.getDocIds());
@@ -9,12 +9,14 @@

package org.elasticsearch.search.rank.feature;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.IndicesRequest;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.tasks.TaskId;
@@ -38,16 +40,20 @@ public class RankFeatureShardRequest extends AbstractTransportRequest implements

private final int[] docIds;

private final RerankSnippetInput snippets;

public RankFeatureShardRequest(
OriginalIndices originalIndices,
ShardSearchContextId contextId,
ShardSearchRequest shardSearchRequest,
List<Integer> docIds
List<Integer> docIds,
@Nullable RerankSnippetInput snippets
) {
this.originalIndices = originalIndices;
this.shardSearchRequest = shardSearchRequest;
this.docIds = docIds.stream().flatMapToInt(IntStream::of).toArray();
this.contextId = contextId;
this.snippets = snippets;
}

public RankFeatureShardRequest(StreamInput in) throws IOException {
Expand All @@ -56,6 +62,11 @@ public RankFeatureShardRequest(StreamInput in) throws IOException {
shardSearchRequest = in.readOptionalWriteable(ShardSearchRequest::new);
docIds = in.readIntArray();
contextId = in.readOptionalWriteable(ShardSearchContextId::new);
if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
snippets = in.readOptionalWriteable(RerankSnippetInput::new);
} else {
snippets = null;
}
}

@Override
@@ -65,6 +76,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalWriteable(shardSearchRequest);
out.writeIntArray(docIds);
out.writeOptionalWriteable(contextId);
if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
out.writeOptionalWriteable(snippets);
}
}

@Override
@@ -95,6 +109,10 @@ public ShardSearchContextId contextId() {
return contextId;
}

public RerankSnippetInput snippets() {
return snippets;
}

@Override
public SearchShardTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new SearchShardTask(id, type, action, getDescription(), parentTaskId, headers);
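For completeness, a hypothetical call site for the widened constructor is sketched below; it is not part of this change. The surrounding objects are taken as parameters rather than built here, the snippet values are arbitrary, and passing null for the last argument is allowed (it is @Nullable) and simply disables snippet generation on the shard.

import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
import org.elasticsearch.search.rank.feature.RerankSnippetInput;

import java.util.List;

// Hypothetical helper, not part of this change: shows the new request shape end to end.
final class RankFeatureShardRequests {

    static RankFeatureShardRequest withSnippets(
        OriginalIndices indices,
        ShardSearchContextId contextId,
        ShardSearchRequest shardSearchRequest,
        List<Integer> docIds
    ) {
        // One fragment of at most 256 characters (arbitrary); null for either field defers to highlighter defaults.
        return new RankFeatureShardRequest(indices, contextId, shardSearchRequest, docIds, new RerankSnippetInput(1, 256));
    }
}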
@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.rank.feature;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

import java.io.IOException;

public record RerankSnippetInput(Integer numFragments, Integer maxSize) implements Writeable {

public RerankSnippetInput(StreamInput in) throws IOException {
this(in.readOptionalVInt(), in.readOptionalVInt());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalVInt(numFragments);
out.writeOptionalVInt(maxSize);
}
}
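As a quick sanity check on the new Writeable, a wire round-trip can be sketched with BytesStreamOutput, the in-memory StreamOutput commonly used in serialization tests. The class below is illustrative and the values arbitrary.

import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.search.rank.feature.RerankSnippetInput;

public class RerankSnippetInputRoundTrip {
    public static void main(String[] args) throws Exception {
        // Both fields are optional on the wire; a null maxSize falls back to the highlighter's default fragment size.
        RerankSnippetInput original = new RerankSnippetInput(2, null);

        BytesStreamOutput out = new BytesStreamOutput();
        original.writeTo(out);

        try (StreamInput in = out.bytes().streamInput()) {
            RerankSnippetInput copy = new RerankSnippetInput(in);
            if (original.equals(copy) == false) {
                throw new AssertionError("round-trip should preserve both optional fields");
            }
        }
    }
}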
@@ -12,14 +12,19 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
import org.elasticsearch.search.rank.RankShardResult;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
import org.elasticsearch.search.rank.feature.RankFeatureShardResult;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
* The {@code ReRankingRankFeaturePhaseRankShardContext} handles the {@code SearchHits} generated from the {@code RankFeatureShardPhase}
@@ -38,12 +43,27 @@ public RerankingRankFeaturePhaseRankShardContext(String field) {
public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
try {
RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
int docIndex = 0;
for (int i = 0; i < hits.getHits().length; i++) {
rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId);
DocumentField docField = hits.getHits()[i].field(field);
SearchHit hit = hits.getHits()[i];
DocumentField docField = hit.field(field);
if (docField != null) {
rankFeatureDocs[i].featureData(docField.getValue().toString());
}
Map<String, HighlightField> highlightFields = hit.getHighlightFields();
if (highlightFields != null) {
if (highlightFields.containsKey(field)) {
List<String> snippets = Arrays.stream(highlightFields.get(field).fragments()).map(Text::string).toList();
List<Integer> docIndices = new ArrayList<>();
for (String snippet : snippets) {
docIndices.add(docIndex);
}
rankFeatureDocs[i].snippets(snippets);
rankFeatureDocs[i].docIndices(docIndices);
}
}
docIndex++;
}
return new RankFeatureShardResult(rankFeatureDocs);
} catch (Exception ex) {
@@ -22,7 +22,7 @@ public class RandomRankFeaturePhaseRankCoordinatorContext extends RankFeaturePha
private final Integer seed;

public RandomRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, Integer seed) {
super(size, from, rankWindowSize, false);
super(size, from, rankWindowSize, false, null);
this.seed = seed;
}
