diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index 0a62e9f968e4f..ad01cbd738611 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -279,8 +279,8 @@ public final boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; RetrieverBuilder that = (RetrieverBuilder) o; - return Objects.equals(preFilterQueryBuilders, that.preFilterQueryBuilders) - && Objects.equals(minScore, that.minScore) + return Objects.equals(getPreFilterQueryBuilders(), that.getPreFilterQueryBuilders()) + && Objects.equals(minScore(), that.minScore()) && doEquals(o); } @@ -288,7 +288,7 @@ public final boolean equals(Object o) { @Override public final int hashCode() { - return Objects.hash(getClass(), preFilterQueryBuilders, minScore, doHashCode()); + return Objects.hash(getClass(), getPreFilterQueryBuilders(), minScore(), doHashCode()); } protected abstract int doHashCode(); diff --git a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java index 02d01bb2f45fc..e6592067035a0 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java @@ -104,6 +104,10 @@ private StandardRetrieverBuilder(StandardRetrieverBuilder clone) { this.terminateAfter = clone.terminateAfter; } + public QueryBuilder getQueryBuilder() { + return queryBuilder; + } + @Override public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { boolean changed = false; diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 78f30e7da0670..a9a3f8bc9ec39 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -43,6 +43,7 @@ exports org.elasticsearch.xpack.inference; exports org.elasticsearch.xpack.inference.action.task; exports org.elasticsearch.xpack.inference.telemetry; + exports org.elasticsearch.xpack.inference.rank.textsimilarity; provides org.elasticsearch.features.FeatureSpecification with org.elasticsearch.xpack.inference.InferenceFeatures; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index d6883d3743a1d..51fd95b263be9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -23,6 +23,7 @@ import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -49,7 +50,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> { RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[0]; - String inferenceId = args[1] == null ? DEFAULT_RERANK_ID : (String) args[1]; + String inferenceId = (String) args[1]; String inferenceText = (String) args[2]; String field = (String) args[3]; int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4]; @@ -104,11 +105,17 @@ public TextSimilarityRankRetrieverBuilder( int rankWindowSize, boolean failuresAllowed ) { - super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize); - this.inferenceId = inferenceId; - this.inferenceText = inferenceText; - this.field = field; - this.failuresAllowed = failuresAllowed; + this( + List.of(new RetrieverSource(retrieverBuilder, null)), + inferenceId, + inferenceText, + field, + rankWindowSize, + null, + failuresAllowed, + null, + new ArrayList<>() + ); } public TextSimilarityRankRetrieverBuilder( @@ -124,9 +131,11 @@ public TextSimilarityRankRetrieverBuilder( ) { super(retrieverSource, rankWindowSize); if (retrieverSource.size() != 1) { - throw new IllegalArgumentException("[" + getName() + "] retriever should have exactly one inner retriever"); + throw new IllegalArgumentException( + "[" + TextSimilarityRankBuilder.NAME + "] retriever should have exactly one inner retriever" + ); } - this.inferenceId = inferenceId; + this.inferenceId = inferenceId == null ? DEFAULT_RERANK_ID : inferenceId; this.inferenceText = inferenceText; this.field = field; this.minScore = minScore; diff --git a/x-pack/plugin/rank-rrf/build.gradle b/x-pack/plugin/rank-rrf/build.gradle index 216e85f48f56f..0588140b47ce8 100644 --- a/x-pack/plugin/rank-rrf/build.gradle +++ b/x-pack/plugin/rank-rrf/build.gradle @@ -13,11 +13,12 @@ esplugin { name = 'rank-rrf' description = 'Reciprocal rank fusion in search.' classname ='org.elasticsearch.xpack.rank.rrf.RRFRankPlugin' - extendedPlugins = ['x-pack-core'] + extendedPlugins = ['x-pack-core', 'x-pack-inference'] } dependencies { compileOnly project(path: xpackModule('core')) + compileOnly project(path: xpackModule('inference')) testImplementation(testArtifact(project(xpackModule('core')))) testImplementation(testArtifact(project(':server'))) diff --git a/x-pack/plugin/rank-rrf/src/main/java/module-info.java b/x-pack/plugin/rank-rrf/src/main/java/module-info.java index fbe467fdf3eae..42bfcdcdc19e4 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/module-info.java +++ b/x-pack/plugin/rank-rrf/src/main/java/module-info.java @@ -13,6 +13,7 @@ requires org.elasticsearch.xcontent; requires org.elasticsearch.server; requires org.elasticsearch.xcore; + requires org.elasticsearch.inference; exports org.elasticsearch.xpack.rank; exports org.elasticsearch.xpack.rank.rrf; diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/HybridRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/HybridRetrieverBuilder.java new file mode 100644 index 0000000000000..341f578d52480 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/HybridRetrieverBuilder.java @@ -0,0 +1,373 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.rank.hybrid; + +import org.elasticsearch.common.collect.ImmutableOpenMap; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.index.query.MatchQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.rank.RankBuilder; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverBuilderWrapper; +import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; +import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder; +import org.elasticsearch.xpack.rank.linear.MinMaxScoreNormalizer; +import org.elasticsearch.xpack.rank.linear.ScoreNormalizer; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.RANK_WINDOW_SIZE_FIELD; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.rank.hybrid.QuerySettings.TYPE_FIELD; + +// TODO: +// - Retriever name support + +public class HybridRetrieverBuilder extends RetrieverBuilderWrapper { + public static final String NAME = "hybrid"; + public static final ParseField FIELDS_FIELD = new ParseField("fields"); + public static final ParseField QUERY_FIELD = new ParseField("query"); + public static final ParseField RERANK_FIELD = new ParseField("rerank"); + public static final ParseField RERANK_FIELD_FIELD = new ParseField("rerank_field"); + public static final ParseField RERANK_INFERENCE_ID_FIELD = new ParseField("rerank_inference_id"); + public static final ParseField QUERY_SETTINGS_FIELD = new ParseField("query_settings"); + + private final List fields; + private final String query; + private final Boolean rerank; + private final String rerankField; + private final String rerankInferenceId; + private final Map> querySettingsMap; + private final int rankWindowSize; + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + false, + (args, context) -> { + List fields = (List) args[0]; + String query = (String) args[1]; + Boolean rerank = (Boolean) args[2]; + String rerankField = (String) args[3]; + String rerankInferenceId = (String) args[4]; + int rankWindowSize = args[5] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[5]; + Map> querySettingsMap = (Map>) args[6]; + + return new HybridRetrieverBuilder(fields, query, rerank, rerankField, rerankInferenceId, querySettingsMap, rankWindowSize); + } + ); + + private static final NamedXContentRegistry NAMED_X_CONTENT_REGISTRY; + + static { + List xContentRegistryEntries = List.of( + new NamedXContentRegistry.Entry( + QuerySettings.class, + new ParseField(MatchQuerySettings.QUERY_TYPE.getQueryName()), + MatchQuerySettings::fromXContent + ), + new NamedXContentRegistry.Entry( + QuerySettings.class, + new ParseField(MatchPhraseQuerySettings.QUERY_TYPE.getQueryName()), + MatchPhraseQuerySettings::fromXContent + ) + ); + + NAMED_X_CONTENT_REGISTRY = new NamedXContentRegistry(xContentRegistryEntries); + + PARSER.declareStringArray(constructorArg(), FIELDS_FIELD); + PARSER.declareString(constructorArg(), QUERY_FIELD); + PARSER.declareBoolean(optionalConstructorArg(), RERANK_FIELD); + PARSER.declareString(optionalConstructorArg(), RERANK_FIELD_FIELD); + PARSER.declareString(optionalConstructorArg(), RERANK_INFERENCE_ID_FIELD); + PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> { + Map> querySettingsMap = new HashMap<>(); + + Map unparsedMap = p.map(); + for (var entry : unparsedMap.entrySet()) { + String field = entry.getKey(); + Object value = entry.getValue(); + + if (value instanceof List list) { + List querySettings = querySettingsMap.computeIfAbsent(field, f -> new ArrayList<>(list.size())); + for (Object listValue : list) { + if (listValue instanceof Map map) { + querySettings.add(parseQuerySettings(map)); + } else { + throw new IllegalArgumentException( + "Query settings for field [" + field + "] must be an object or list of objects" + ); + } + } + } else if (value instanceof Map map) { + List querySettings = querySettingsMap.computeIfAbsent(field, f -> new ArrayList<>()); + querySettings.add(parseQuerySettings(map)); + } else { + throw new IllegalArgumentException("Query settings for field [" + field + "] must be an object or list of objects"); + } + } + + return querySettingsMap; + }, QUERY_SETTINGS_FIELD); + RetrieverBuilder.declareBaseParserFields(PARSER); + } + + public HybridRetrieverBuilder( + List fields, + String query, + Boolean rerank, + String rerankField, + String rerankInferenceId, + Map> querySettingsMap, + int rankWindowSize + ) { + this( + fields == null ? List.of() : List.copyOf(fields), + query, + rerank, + rerankField, + rerankInferenceId, + copyQuerySettingsMap(querySettingsMap), + rankWindowSize, + generateRetrieverBuilder(fields, query, rerank, rerankField, rerankInferenceId, querySettingsMap, rankWindowSize) + ); + } + + private HybridRetrieverBuilder( + List fields, + String query, + Boolean rerank, + String rerankField, + String rerankInferenceId, + Map> querySettingsMap, + int rankWindowSize, + RetrieverBuilder retrieverBuilder + ) { + super(retrieverBuilder); + this.fields = fields; + this.query = query; + this.rerank = rerank; + this.rerankField = rerankField; + this.rerankInferenceId = rerankInferenceId; + this.querySettingsMap = querySettingsMap; + this.rankWindowSize = rankWindowSize; + } + + @Override + protected HybridRetrieverBuilder clone(RetrieverBuilder sub) { + return new HybridRetrieverBuilder(fields, query, rerank, rerankField, rerankInferenceId, querySettingsMap, rankWindowSize, sub); + } + + @Override + public String getName() { + return NAME; + } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(FIELDS_FIELD.getPreferredName(), fields); + builder.field(QUERY_FIELD.getPreferredName(), query); + if (rerank != null) { + builder.field(RERANK_FIELD.getPreferredName(), rerank); + } + if (rerankField != null) { + builder.field(RERANK_FIELD_FIELD.getPreferredName(), rerankField); + } + if (rerankInferenceId != null) { + builder.field(RERANK_INFERENCE_ID_FIELD.getPreferredName(), rerankInferenceId); + } + builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); + } + + @Override + protected boolean doEquals(Object o) { + // TODO: Check rankWindowSize? It should be checked by the wrapped retriever. + HybridRetrieverBuilder that = (HybridRetrieverBuilder) o; + return Objects.equals(fields, that.fields) + && Objects.equals(query, that.query) + && Objects.equals(rerank, that.rerank) + && Objects.equals(rerankField, that.rerankField) + && Objects.equals(rerankInferenceId, that.rerankInferenceId) + && super.doEquals(o); + } + + @Override + protected int doHashCode() { + return Objects.hash(fields, query, rerank, rerankField, rerankInferenceId, super.doHashCode()); + } + + public static HybridRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { + return PARSER.apply(parser, context); + } + + private static Map> copyQuerySettingsMap(Map> querySettingsMap) { + if (querySettingsMap == null) { + return Map.of(); + } + + ImmutableOpenMap.Builder> copyBuilder = new ImmutableOpenMap.Builder<>(querySettingsMap.size()); + for (var entry : querySettingsMap.entrySet()) { + String field = entry.getKey(); + List querySettings = entry.getValue(); + + copyBuilder.put(field, querySettings != null ? List.copyOf(querySettings) : List.of()); + } + + return copyBuilder.build(); + } + + private static RetrieverBuilder generateRetrieverBuilder( + List fields, + String query, + Boolean rerank, + String rerankField, + String rerankInferenceId, + Map> querySettingsMap, + int rankWindowSize + ) { + FieldsAndsWeights fieldsAndsWeights = generateFieldsAndWeights(fields); + + LinearRetrieverBuilder linearRetrieverBuilder = new LinearRetrieverBuilder( + generateInnerRetrievers(fieldsAndsWeights.fields(), query, querySettingsMap), + rankWindowSize, + fieldsAndsWeights.weights(), + generateScoreNormalizers(fields) + ); + + RetrieverBuilder rootRetriever = linearRetrieverBuilder; + if (rerank != null && rerank) { + if (rerankField == null) { + throw new IllegalArgumentException("[" + RERANK_FIELD_FIELD.getPreferredName() + "] is required when reranking is enabled"); + } + + rootRetriever = new TextSimilarityRankRetrieverBuilder( + linearRetrieverBuilder, + rerankInferenceId, + query, + rerankField, + rankWindowSize, + false + ); + } + + return rootRetriever; + } + + private static List generateInnerRetrievers( + List fields, + String query, + Map> querySettingsMap + ) { + if (fields == null) { + return List.of(); + } + + List innerRetrievers = new ArrayList<>(); + for (String field : fields) { + List fieldQueryBuilders = new ArrayList<>(); + List fieldQuerySettings = querySettingsMap != null ? querySettingsMap.get(field) : null; + if (fieldQuerySettings == null || fieldQuerySettings.isEmpty()) { + // Default to match query + fieldQueryBuilders.add(new MatchQueryBuilder(field, query)); + } else { + for (QuerySettings querySettings : fieldQuerySettings) { + fieldQueryBuilders.add(querySettings.constructQueryBuilder(field, query)); + } + } + + for (QueryBuilder queryBuilder : fieldQueryBuilders) { + innerRetrievers.add(new CompoundRetrieverBuilder.RetrieverSource(new StandardRetrieverBuilder(queryBuilder), null)); + } + } + + return innerRetrievers; + } + + private static FieldsAndsWeights generateFieldsAndWeights(List fields) { + if (fields == null) { + return new FieldsAndsWeights(List.of(), new float[0]); + } + + int fieldCount = fields.size(); + List parsedFields = new ArrayList<>(fieldCount); + float[] parsedWeights = new float[fieldCount]; + for (int i = 0; i < fieldCount; i++) { + String[] fieldSplit = fields.get(i).split("\\^"); + + float weight = 1.0f; + if (fieldSplit.length > 2) { + throw new IllegalArgumentException("Invalid field name [" + fields.get(i) + "]"); + } else if (fieldSplit.length == 2) { + weight = Float.parseFloat(fieldSplit[1]); + } + + parsedFields.add(fieldSplit[0]); + parsedWeights[i] = weight; + } + + return new FieldsAndsWeights(Collections.unmodifiableList(parsedFields), parsedWeights); + } + + private static ScoreNormalizer[] generateScoreNormalizers(List fields) { + if (fields == null) { + return new ScoreNormalizer[0]; + } + + ScoreNormalizer[] scoreNormalizers = new ScoreNormalizer[fields.size()]; + Arrays.fill(scoreNormalizers, new MinMaxScoreNormalizer(0)); + return scoreNormalizers; + } + + private record FieldsAndsWeights(List fields, float[] weights) {} + + // TODO: Probably a better way to do this, but this is quick & dirty for POC purposes + private static QuerySettings parseQuerySettings(Map map) { + Map querySettingsMap = XContentMapValues.nodeMapValue(map, "query settings"); + + Object typeObject = querySettingsMap.get(TYPE_FIELD.getPreferredName()); + if (typeObject == null) { + throw new IllegalArgumentException("[" + TYPE_FIELD.getPreferredName() + "] must be provided in query settings"); + } else if (typeObject instanceof String == false) { + throw new IllegalArgumentException("[" + TYPE_FIELD.getPreferredName() + "] must have a string value"); + } + + String typeString = (String) typeObject; + MapXContentParser mapXContentParser = new MapXContentParser( + NAMED_X_CONTENT_REGISTRY, + LoggingDeprecationHandler.INSTANCE, + querySettingsMap, + null + ); + + try (mapXContentParser) { + return mapXContentParser.namedObject(QuerySettings.class, typeString, null); + } catch (IOException e) { + throw new XContentParseException(mapXContentParser.getTokenLocation(), "Failed to parse query settings"); + } + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/MatchPhraseQuerySettings.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/MatchPhraseQuerySettings.java new file mode 100644 index 0000000000000..2907850fe3686 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/MatchPhraseQuerySettings.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.rank.hybrid; + +import org.elasticsearch.index.query.MatchPhraseQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class MatchPhraseQuerySettings implements QuerySettings { + public static final QueryType QUERY_TYPE = QueryType.MATCH_PHRASE; + + public static final ParseField SLOP_FIELD = new ParseField("slop"); + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "match_phrase_query_settings", + false, + args -> { + String typeString = (String) args[0]; + Integer slop = (Integer) args[1]; + + QueryType queryType = QueryType.fromString(typeString); + if (queryType != QUERY_TYPE) { + throw new IllegalStateException("Query type must be " + QUERY_TYPE); + } + + return new MatchPhraseQuerySettings(slop); + } + ); + + static { + PARSER.declareString(constructorArg(), TYPE_FIELD); + PARSER.declareInt(optionalConstructorArg(), SLOP_FIELD); + } + + private final Integer slop; + + public MatchPhraseQuerySettings(Integer slop) { + this.slop = slop; + } + + public Integer getSlop() { + return slop; + } + + @Override + public QueryType getQueryType() { + return QUERY_TYPE; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TYPE_FIELD.getPreferredName(), getName()); + if (slop != null) { + builder.field(SLOP_FIELD.getPreferredName(), slop); + } + builder.endObject(); + + return builder; + } + + @Override + public QueryBuilder constructQueryBuilder(String field, String query) { + MatchPhraseQueryBuilder queryBuilder = new MatchPhraseQueryBuilder(field, query); + if (slop != null) { + queryBuilder.slop(getSlop()); + } + + return queryBuilder; + } + + public static MatchPhraseQuerySettings fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/MatchQuerySettings.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/MatchQuerySettings.java new file mode 100644 index 0000000000000..e9d5beaf236b4 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/MatchQuerySettings.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.rank.hybrid; + +import org.elasticsearch.index.query.MatchQueryBuilder; +import org.elasticsearch.index.query.Operator; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class MatchQuerySettings implements QuerySettings { + public static final QueryType QUERY_TYPE = QueryType.MATCH; + + public static final ParseField OPERATOR_FIELD = new ParseField("operator"); + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "match_query_settings", + false, + args -> { + String typeString = (String) args[0]; + String operatorString = (String) args[1]; + + QueryType queryType = QueryType.fromString(typeString); + if (queryType != QueryType.MATCH) { + throw new IllegalStateException("Query type must be " + QueryType.MATCH); + } + + return new MatchQuerySettings( + operatorString != null ? Operator.fromString(operatorString) : MatchQueryBuilder.DEFAULT_OPERATOR + ); + } + ); + + static { + PARSER.declareString(constructorArg(), TYPE_FIELD); + PARSER.declareString(optionalConstructorArg(), OPERATOR_FIELD); + } + + private final Operator operator; + + public MatchQuerySettings(Operator operator) { + this.operator = operator; + } + + public Operator getOperator() { + return operator; + } + + @Override + public QueryType getQueryType() { + return QUERY_TYPE; + } + + @Override + public QueryBuilder constructQueryBuilder(String field, String query) { + return new MatchQueryBuilder(field, query).operator(getOperator()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TYPE_FIELD.getPreferredName(), getName()); + builder.field(OPERATOR_FIELD.getPreferredName(), operator); + builder.endObject(); + + return builder; + } + + public static MatchQuerySettings fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/QuerySettings.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/QuerySettings.java new file mode 100644 index 0000000000000..9061af538737f --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/QuerySettings.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.rank.hybrid; + +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; + +public interface QuerySettings extends NamedXContentObject { + ParseField TYPE_FIELD = new ParseField("type"); + + QueryType getQueryType(); + + QueryBuilder constructQueryBuilder(String field, String query); + + @Override + default String getName() { + return getQueryType().getQueryName(); + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/QueryType.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/QueryType.java new file mode 100644 index 0000000000000..af719efb5af01 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/hybrid/QueryType.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.rank.hybrid; + +import java.util.Locale; + +public enum QueryType { + MATCH, + MATCH_PHRASE; + + public String getQueryName() { + return name().toLowerCase(Locale.ROOT); + } + + public static QueryType fromString(String queryName) { + return valueOf(queryName.toUpperCase(Locale.ROOT)); + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index 436096523a1ec..2058c1b4c2986 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -205,4 +205,6 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept } builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); } + + // TODO: Need doEquals & doHashCode to check weights and normalizers } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java index 56b42b48a5d47..7ca27eef28717 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java @@ -17,7 +17,15 @@ public class MinMaxScoreNormalizer extends ScoreNormalizer { private static final float EPSILON = 1e-6f; - public MinMaxScoreNormalizer() {} + private final float initialMin; + + public MinMaxScoreNormalizer() { + this(Float.MAX_VALUE); + } + + public MinMaxScoreNormalizer(float initialMin) { + this.initialMin = initialMin; + } @Override public String getName() { @@ -31,7 +39,7 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { } // create a new array to avoid changing ScoreDocs in place ScoreDoc[] scoreDocs = new ScoreDoc[docs.length]; - float min = Float.MAX_VALUE; + float min = initialMin; float max = Float.MIN_VALUE; boolean atLeastOneValidScore = false; for (ScoreDoc rd : docs) { diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java index 251015b21ff50..5cd9dcaad5339 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java @@ -17,6 +17,7 @@ import org.elasticsearch.search.rank.RankShardResult; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xpack.rank.hybrid.HybridRetrieverBuilder; import org.elasticsearch.xpack.rank.linear.LinearRankDoc; import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder; @@ -57,7 +58,8 @@ public List getNamedXContent() { public List> getRetrievers() { return List.of( new RetrieverSpec<>(new ParseField(NAME), RRFRetrieverBuilder::fromXContent), - new RetrieverSpec<>(new ParseField(LinearRetrieverBuilder.NAME), LinearRetrieverBuilder::fromXContent) + new RetrieverSpec<>(new ParseField(LinearRetrieverBuilder.NAME), LinearRetrieverBuilder::fromXContent), + new RetrieverSpec<>(new ParseField(HybridRetrieverBuilder.NAME), HybridRetrieverBuilder::fromXContent) ); } }