Skip to content

Commit 330e32b

Browse files
committed
modified the builder
1 parent d36ada2 commit 330e32b

File tree

1 file changed

+23
-28
lines changed

1 file changed

+23
-28
lines changed

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ private static float[] getDefaultWeight(List<RetrieverSource> innerRetrievers) {
119119
private static ScoreNormalizer[] getDefaultNormalizers(List<RetrieverSource> innerRetrievers) {
120120
int size = innerRetrievers != null ? innerRetrievers.size() : 0;
121121
ScoreNormalizer[] normalizers = new ScoreNormalizer[size];
122-
return new ScoreNormalizer[size];
122+
Arrays.fill(normalizers, DEFAULT_NORMALIZER);
123+
return normalizers;
123124
}
124125

125126
public static LinearRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
@@ -185,16 +186,32 @@ public LinearRetrieverBuilder(
185186
this.normalizer = normalizer;
186187

187188
if (normalizer != null) {
188-
for (ScoreNormalizer subNormalizer : normalizers) {
189-
if (subNormalizer != null && subNormalizer.equals(DEFAULT_NORMALIZER) == false && subNormalizer.equals(normalizer) == false) {
189+
// First pass: validate that any specified per-retriever normalizers match the top-level one
190+
for (int i = 0; i < normalizers.length; i++) {
191+
ScoreNormalizer subNormalizer = normalizers[i];
192+
if (subNormalizer != null && !subNormalizer.equals(DEFAULT_NORMALIZER) && !subNormalizer.equals(normalizer)) {
190193
throw new IllegalArgumentException(
191-
"top-level normalizer ["
194+
"["
195+
+ NAME
196+
+ "] All per-retriever normalizers must match the top-level normalizer: "
197+
+ "expected ["
192198
+ normalizer.getName()
193-
+ "] is specified and it should be the same as all sub-retriever normalizers"
199+
+ "], found ["
200+
+ subNormalizer.getName()
201+
+ "] in retriever ["
202+
+ i
203+
+ "]"
194204
);
195205
}
196206
}
207+
// Second pass: propagate top-level normalizer to any unspecified positions
208+
for (int i = 0; i < normalizers.length; i++) {
209+
if (normalizers[i] == null || normalizers[i].equals(DEFAULT_NORMALIZER)) {
210+
normalizers[i] = normalizer;
211+
}
212+
}
197213
}
214+
198215
}
199216

200217
public LinearRetrieverBuilder(
@@ -250,28 +267,6 @@ public ActionRequestValidationException validate(
250267
);
251268
}
252269

253-
if (normalizer != null) {
254-
for (ScoreNormalizer perRetrieverNormalizer : normalizers) {
255-
boolean isExplicitSubNormalizer = perRetrieverNormalizer != null
256-
&& perRetrieverNormalizer.equals(DEFAULT_NORMALIZER) == false;
257-
boolean isMismatch = isExplicitSubNormalizer && perRetrieverNormalizer.equals(normalizer) == false;
258-
if (isMismatch) {
259-
validationException = addValidationError(
260-
String.format(
261-
Locale.ROOT,
262-
"[%s] top-level [%s] is [%s] but a sub-retriever specifies [%s]",
263-
getName(),
264-
NORMALIZER_FIELD.getPreferredName(),
265-
normalizer.getName(),
266-
perRetrieverNormalizer.getName()
267-
),
268-
validationException
269-
);
270-
break;
271-
}
272-
}
273-
}
274-
275270
return validationException;
276271
}
277272

@@ -418,7 +413,7 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
418413
linearRewritten.innerRetrievers,
419414
linearRewritten.fields,
420415
linearRewritten.query,
421-
null,
416+
normalizer,
422417
linearRewritten.rankWindowSize,
423418
linearRewritten.weights,
424419
newNormalizers

0 commit comments

Comments
 (0)