Skip to content

Commit 34250dc

Browse files
committed
Validation improvements
1 parent a777601 commit 34250dc

File tree

3 files changed

+139
-41
lines changed

3 files changed

+139
-41
lines changed

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.rank.linear;
99

1010
import org.apache.lucene.search.ScoreDoc;
11+
import org.elasticsearch.action.ActionRequestValidationException;
1112
import org.elasticsearch.action.ResolvedIndices;
1213
import org.elasticsearch.common.ParsingException;
1314
import org.elasticsearch.common.util.Maps;
@@ -32,8 +33,10 @@
3233
import java.util.ArrayList;
3334
import java.util.Arrays;
3435
import java.util.List;
36+
import java.util.Locale;
3537
import java.util.Map;
3638

39+
import static org.elasticsearch.action.ValidateActions.addValidationError;
3740
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
3841
import static org.elasticsearch.xpack.rank.RankRRFFeatures.LINEAR_RETRIEVER_SUPPORTED;
3942
import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT;
@@ -63,7 +66,7 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder<Linea
6366
private final ScoreNormalizer[] normalizers;
6467
private final List<String> fields;
6568
private final String query;
66-
private final String normalizer;
69+
private final ScoreNormalizer normalizer;
6770

6871
@SuppressWarnings("unchecked")
6972
static final ConstructingObjectParser<LinearRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
@@ -73,7 +76,7 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder<Linea
7376
List<LinearRetrieverComponent> retrieverComponents = args[0] == null ? List.of() : (List<LinearRetrieverComponent>) args[0];
7477
List<String> fields = (List<String>) args[1];
7578
String query = (String) args[2];
76-
String normalizer = (String) args[3];
79+
ScoreNormalizer normalizer = args[3] == null ? null : ScoreNormalizer.valueOf((String) args[3]);
7780
int rankWindowSize = args[4] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[4];
7881

7982
int index = 0;
@@ -140,12 +143,12 @@ public LinearRetrieverBuilder(
140143
List<RetrieverSource> innerRetrievers,
141144
List<String> fields,
142145
String query,
143-
String normalizer,
146+
ScoreNormalizer normalizer,
144147
int rankWindowSize,
145148
float[] weights,
146149
ScoreNormalizer[] normalizers
147150
) {
148-
// Use a mutable list for innerRetrievers so that we can add more child retrievers during rewrite
151+
// Use a mutable list for innerRetrievers so that we can use addChild
149152
super(innerRetrievers == null ? new ArrayList<>() : new ArrayList<>(innerRetrievers), rankWindowSize);
150153
if (weights.length != this.innerRetrievers.size()) {
151154
throw new IllegalArgumentException("The number of weights must match the number of inner retrievers");
@@ -159,6 +162,55 @@ public LinearRetrieverBuilder(
159162
this.normalizer = normalizer;
160163
this.weights = weights;
161164
this.normalizers = normalizers;
165+
166+
// TODO: Validate simplified query format args here?
167+
// Otherwise some of the validation is skipped when creating the retriever programmatically.
168+
}
169+
170+
@Override
171+
public ActionRequestValidationException validate(
172+
SearchSourceBuilder source,
173+
ActionRequestValidationException validationException,
174+
boolean isScroll,
175+
boolean allowPartialSearchResults
176+
) {
177+
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
178+
validationException = SimplifiedInnerRetrieverUtils.validateSimplifiedFormatParams(
179+
innerRetrievers,
180+
fields,
181+
query,
182+
getName(),
183+
RETRIEVERS_FIELD.getPreferredName(),
184+
FIELDS_FIELD.getPreferredName(),
185+
QUERY_FIELD.getPreferredName(),
186+
validationException
187+
);
188+
189+
if (query != null && normalizer == null) {
190+
validationException = addValidationError(
191+
String.format(
192+
Locale.ROOT,
193+
"[%s] [%s] must be provided when [%s] is specified",
194+
getName(),
195+
NORMALIZER_FIELD.getPreferredName(),
196+
QUERY_FIELD.getPreferredName()
197+
),
198+
validationException
199+
);
200+
} else if (innerRetrievers.isEmpty() == false && normalizer != null) {
201+
validationException = addValidationError(
202+
String.format(
203+
Locale.ROOT,
204+
"[%s] [%s] cannot be provided when [%s] is specified",
205+
getName(),
206+
NORMALIZER_FIELD.getPreferredName(),
207+
RETRIEVERS_FIELD.getPreferredName()
208+
),
209+
validationException
210+
);
211+
}
212+
213+
return validationException;
162214
}
163215

164216
@Override
@@ -233,27 +285,8 @@ protected LinearRetrieverBuilder doRewrite(QueryRewriteContext ctx) {
233285
LinearRetrieverBuilder rewritten = this;
234286

235287
ResolvedIndices resolvedIndices = ctx.getResolvedIndices();
236-
if (resolvedIndices != null && (query != null || fields.isEmpty() == false)) {
288+
if (resolvedIndices != null && query != null) {
237289
// Using the simplified query format
238-
if (query == null || query.isEmpty()) {
239-
throw new IllegalArgumentException(
240-
"[" + NAME + "] [" + QUERY_FIELD.getPreferredName() + "] must be provided when using the simplified query format"
241-
);
242-
}
243-
244-
if (normalizer == null || normalizer.isEmpty()) {
245-
throw new IllegalArgumentException(
246-
"[" + NAME + "] [" + NORMALIZER_FIELD.getPreferredName() + "] must be provided when using the simplified query format"
247-
);
248-
}
249-
ScoreNormalizer fieldsNormalizer = ScoreNormalizer.valueOf(normalizer);
250-
251-
if (innerRetrievers.isEmpty() == false) {
252-
throw new IllegalArgumentException(
253-
"[" + NAME + "] does not support [" + RETRIEVERS_FIELD.getPreferredName() + "] and the simplified query format combined"
254-
);
255-
}
256-
257290
var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata();
258291
if (localIndicesMetadata.size() > 1) {
259292
throw new IllegalArgumentException(
@@ -274,7 +307,7 @@ protected LinearRetrieverBuilder doRewrite(QueryRewriteContext ctx) {
274307
for (var weightedRetriever : r) {
275308
retrievers.add(weightedRetriever.retrieverSource());
276309
weights[index] = weightedRetriever.weight();
277-
normalizers[index] = fieldsNormalizer;
310+
normalizers[index] = normalizer;
278311
index++;
279312
}
280313

@@ -291,7 +324,7 @@ protected LinearRetrieverBuilder doRewrite(QueryRewriteContext ctx) {
291324
Arrays.fill(weights, DEFAULT_WEIGHT);
292325

293326
ScoreNormalizer[] normalizers = new ScoreNormalizer[fieldsInnerRetrievers.size()];
294-
Arrays.fill(normalizers, fieldsNormalizer);
327+
Arrays.fill(normalizers, normalizer);
295328

296329
rewritten = new LinearRetrieverBuilder(fieldsInnerRetrievers, null, null, normalizer, rankWindowSize, weights, normalizers);
297330
}
@@ -330,7 +363,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
330363
builder.field(QUERY_FIELD.getPreferredName(), query);
331364
}
332365
if (normalizer != null) {
333-
builder.field(NORMALIZER_FIELD.getPreferredName(), normalizer);
366+
builder.field(NORMALIZER_FIELD.getPreferredName(), normalizer.getName());
334367
}
335368

336369
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
package org.elasticsearch.xpack.rank.rrf;
99

1010
import org.apache.lucene.search.ScoreDoc;
11+
import org.elasticsearch.action.ActionRequestValidationException;
1112
import org.elasticsearch.action.ResolvedIndices;
1213
import org.elasticsearch.common.util.Maps;
1314
import org.elasticsearch.index.query.QueryBuilder;
1415
import org.elasticsearch.index.query.QueryRewriteContext;
1516
import org.elasticsearch.license.LicenseUtils;
17+
import org.elasticsearch.search.builder.SearchSourceBuilder;
1618
import org.elasticsearch.search.rank.RankBuilder;
1719
import org.elasticsearch.search.rank.RankDoc;
1820
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
@@ -105,18 +107,41 @@ public RRFRetrieverBuilder(int rankWindowSize, int rankConstant) {
105107
}
106108

107109
RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, List<String> fields, String query, int rankWindowSize, int rankConstant) {
108-
// Use a mutable list for childRetrievers so that we can add more child retrievers during rewrite
110+
// Use a mutable list for childRetrievers so that we can use addChild
109111
super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
110112
this.fields = fields == null ? List.of() : List.copyOf(fields);
111113
this.query = query;
112114
this.rankConstant = rankConstant;
115+
116+
// TODO: Validate simplified query format args here?
117+
// Otherwise some of the validation is skipped when creating the retriever programmatically.
113118
}
114119

115120
@Override
116121
public String getName() {
117122
return NAME;
118123
}
119124

125+
@Override
126+
public ActionRequestValidationException validate(
127+
SearchSourceBuilder source,
128+
ActionRequestValidationException validationException,
129+
boolean isScroll,
130+
boolean allowPartialSearchResults
131+
) {
132+
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
133+
return SimplifiedInnerRetrieverUtils.validateSimplifiedFormatParams(
134+
innerRetrievers,
135+
fields,
136+
query,
137+
getName(),
138+
RETRIEVERS_FIELD.getPreferredName(),
139+
FIELDS_FIELD.getPreferredName(),
140+
QUERY_FIELD.getPreferredName(),
141+
validationException
142+
);
143+
}
144+
120145
@Override
121146
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
122147
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
@@ -187,20 +212,8 @@ protected RRFRetrieverBuilder doRewrite(QueryRewriteContext ctx) {
187212
RRFRetrieverBuilder rewritten = this;
188213

189214
ResolvedIndices resolvedIndices = ctx.getResolvedIndices();
190-
if (resolvedIndices != null && (query != null || fields.isEmpty() == false)) {
215+
if (resolvedIndices != null && query != null) {
191216
// Using the simplified query format
192-
if (query == null || query.isEmpty()) {
193-
throw new IllegalArgumentException(
194-
"[" + NAME + "] [" + QUERY_FIELD.getPreferredName() + "] must be provided when using the simplified query format"
195-
);
196-
}
197-
198-
if (innerRetrievers.isEmpty() == false) {
199-
throw new IllegalArgumentException(
200-
"[" + NAME + "] does not support [" + RETRIEVERS_FIELD.getPreferredName() + "] and the simplified query format combined"
201-
);
202-
}
203-
204217
var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata();
205218
if (localIndicesMetadata.size() > 1) {
206219
throw new IllegalArgumentException(

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/simplified/SimplifiedInnerRetrieverUtils.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.rank.simplified;
99

10+
import org.elasticsearch.action.ActionRequestValidationException;
1011
import org.elasticsearch.cluster.metadata.IndexMetadata;
1112
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
1213
import org.elasticsearch.common.regex.Regex;
@@ -23,17 +24,68 @@
2324
import java.util.Collection;
2425
import java.util.HashMap;
2526
import java.util.List;
27+
import java.util.Locale;
2628
import java.util.Map;
2729
import java.util.function.Consumer;
2830
import java.util.function.Function;
2931

32+
import static org.elasticsearch.action.ValidateActions.addValidationError;
3033
import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING;
3134

3235
public class SimplifiedInnerRetrieverUtils {
3336
private SimplifiedInnerRetrieverUtils() {}
3437

3538
public record WeightedRetrieverSource(CompoundRetrieverBuilder.RetrieverSource retrieverSource, float weight) {}
3639

40+
public static ActionRequestValidationException validateSimplifiedFormatParams(
41+
List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers,
42+
List<String> fields,
43+
@Nullable String query,
44+
String retrieverName,
45+
String retrieversParamName,
46+
String fieldsParamName,
47+
String queryParamName,
48+
ActionRequestValidationException validationException
49+
) {
50+
if (fields.isEmpty() == false || query != null) {
51+
// Using the simplified query format
52+
if (query == null) {
53+
// Return early here because the following validation checks assume a query param value is provided
54+
return addValidationError(
55+
String.format(
56+
Locale.ROOT,
57+
"[%s] [%s] must be provided when [%s] is specified",
58+
retrieverName,
59+
queryParamName,
60+
fieldsParamName
61+
),
62+
validationException
63+
);
64+
}
65+
66+
if (query.isEmpty()) {
67+
validationException = addValidationError(
68+
String.format(Locale.ROOT, "[%s] [%s] cannot be empty", retrieverName, queryParamName),
69+
validationException
70+
);
71+
}
72+
73+
if (innerRetrievers.isEmpty() == false) {
74+
validationException = addValidationError(
75+
String.format(Locale.ROOT, "[%s] cannot combine [%s] and [%s]", retrieverName, retrieversParamName, queryParamName),
76+
validationException
77+
);
78+
}
79+
} else if (innerRetrievers.isEmpty()) {
80+
validationException = addValidationError(
81+
String.format(Locale.ROOT, "[%s] must provide [%s] or [%s]", retrieverName, retrieversParamName, queryParamName),
82+
validationException
83+
);
84+
}
85+
86+
return validationException;
87+
}
88+
3789
public static List<RetrieverBuilder> generateInnerRetrievers(
3890
List<String> fieldsAndWeights,
3991
String query,

0 commit comments

Comments
 (0)