Skip to content

Commit 1ac18fe

Browse files
committed
propagate
1 parent b447b33 commit 1ac18fe

File tree

6 files changed

+29
-18
lines changed

6 files changed

+29
-18
lines changed

server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ public static FieldBasedRankBuilder fromXContent(XContentParser parser) throws I
100100
}
101101

102102
public FieldBasedRankBuilder(final int rankWindowSize, final String field) {
103-
super(rankWindowSize);
103+
super(rankWindowSize, false);
104104
this.field = field;
105105
}
106106

@@ -205,7 +205,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
205205

206206
@Override
207207
public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) {
208-
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize()) {
208+
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize(), isLenient()) {
209209
@Override
210210
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
211211
float[] scores = new float[featureDocs.length];
@@ -346,7 +346,7 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
346346
@Override
347347
public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from, Client client) {
348348
if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT)
349-
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize()) {
349+
return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize(), isLenient()) {
350350
@Override
351351
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
352352
throw new UnsupportedOperationException("rfc - simulated failure");

server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ public static class TestRerankingRankFeaturePhaseRankCoordinatorContext extends
249249
String inferenceText,
250250
float minScore
251251
) {
252-
super(size, from, windowSize);
252+
super(size, from, windowSize, false);
253253
this.client = client;
254254
this.inferenceId = inferenceId;
255255
this.inferenceText = inferenceText;
@@ -337,7 +337,7 @@ public MockRequestActionBasedRankBuilder(
337337
final String inferenceText,
338338
final float minScore
339339
) {
340-
super(rankWindowSize);
340+
super(rankWindowSize, false);
341341
this.field = field;
342342
this.inferenceId = inferenceId;
343343
this.inferenceText = inferenceText;

server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,15 @@ public final boolean equals(Object obj) {
152152
if (obj == null || getClass() != obj.getClass()) {
153153
return false;
154154
}
155-
@SuppressWarnings("unchecked")
156155
RankBuilder other = (RankBuilder) obj;
157-
return Objects.equals(rankWindowSize, other.rankWindowSize()) && doEquals(other);
156+
return rankWindowSize == other.rankWindowSize && lenient == other.lenient && doEquals(other);
158157
}
159158

160159
protected abstract boolean doEquals(RankBuilder other);
161160

162161
@Override
163162
public final int hashCode() {
164-
return Objects.hash(getClass(), rankWindowSize, doHashCode());
163+
return Objects.hash(getClass(), rankWindowSize, lenient, doHashCode());
165164
}
166165

167166
protected abstract int doHashCode();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
4444
public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
4545
public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text");
4646
public static final ParseField FIELD_FIELD = new ParseField("field");
47+
public static final ParseField LENIENT_FIELD = new ParseField("lenient");
4748

4849
public static final ConstructingObjectParser<TextSimilarityRankRetrieverBuilder, RetrieverParserContext> PARSER =
4950
new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> {
@@ -54,7 +55,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
5455
int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4];
5556
boolean lenient = args[5] != null && (Boolean) args[5];
5657

57-
return new TextSimilarityRankRetrieverBuilder(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize);
58+
return new TextSimilarityRankRetrieverBuilder(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize, lenient);
5859
});
5960

6061
static {
@@ -67,6 +68,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
6768
PARSER.declareString(constructorArg(), INFERENCE_TEXT_FIELD);
6869
PARSER.declareString(constructorArg(), FIELD_FIELD);
6970
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
71+
PARSER.declareBoolean(optionalConstructorArg(), LENIENT_FIELD);
7072

7173
RetrieverBuilder.declareBaseParserFields(TextSimilarityRankBuilder.NAME, PARSER);
7274
}
@@ -85,18 +87,21 @@ public static TextSimilarityRankRetrieverBuilder fromXContent(
8587
private final String inferenceId;
8688
private final String inferenceText;
8789
private final String field;
90+
private final boolean lenient;
8891

8992
public TextSimilarityRankRetrieverBuilder(
9093
RetrieverBuilder retrieverBuilder,
9194
String inferenceId,
9295
String inferenceText,
9396
String field,
94-
int rankWindowSize
97+
int rankWindowSize,
98+
boolean lenient
9599
) {
96100
super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize);
97101
this.inferenceId = inferenceId;
98102
this.inferenceText = inferenceText;
99103
this.field = field;
104+
this.lenient = lenient;
100105
}
101106

102107
public TextSimilarityRankRetrieverBuilder(
@@ -106,6 +111,7 @@ public TextSimilarityRankRetrieverBuilder(
106111
String field,
107112
int rankWindowSize,
108113
Float minScore,
114+
boolean lenient,
109115
String retrieverName,
110116
List<QueryBuilder> preFilterQueryBuilders
111117
) {
@@ -117,6 +123,7 @@ public TextSimilarityRankRetrieverBuilder(
117123
this.inferenceText = inferenceText;
118124
this.field = field;
119125
this.minScore = minScore;
126+
this.lenient = lenient;
120127
this.retrieverName = retrieverName;
121128
this.preFilterQueryBuilders = preFilterQueryBuilders;
122129
}
@@ -133,6 +140,7 @@ protected TextSimilarityRankRetrieverBuilder clone(
133140
field,
134141
rankWindowSize,
135142
minScore,
143+
lenient,
136144
retrieverName,
137145
newPreFilterQueryBuilders
138146
);
@@ -163,9 +171,7 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
163171

164172
@Override
165173
protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
166-
sourceBuilder.rankBuilder(
167-
new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore, false)
168-
);
174+
sourceBuilder.rankBuilder(new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, lenient));
169175
return sourceBuilder;
170176
}
171177

@@ -189,6 +195,9 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc
189195
builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText);
190196
builder.field(FIELD_FIELD.getPreferredName(), field);
191197
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
198+
if (lenient) {
199+
builder.field(LENIENT_FIELD.getPreferredName(), lenient);
200+
}
192201
}
193202

194203
@Override
@@ -198,12 +207,13 @@ public boolean doEquals(Object other) {
198207
&& Objects.equals(inferenceId, that.inferenceId)
199208
&& Objects.equals(inferenceText, that.inferenceText)
200209
&& Objects.equals(field, that.field)
201-
&& Objects.equals(rankWindowSize, that.rankWindowSize)
202-
&& Objects.equals(minScore, that.minScore);
210+
&& rankWindowSize == that.rankWindowSize
211+
&& Objects.equals(minScore, that.minScore)
212+
&& lenient == that.lenient;
203213
}
204214

205215
@Override
206216
public int doHashCode() {
207-
return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore);
217+
return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, lenient);
208218
}
209219
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ public static TextSimilarityRankRetrieverBuilder createRandomTextSimilarityRankR
5757
randomAlphaOfLength(10),
5858
randomAlphaOfLength(20),
5959
randomAlphaOfLength(50),
60-
randomIntBetween(100, 10000)
60+
randomIntBetween(100, 10000),
61+
randomBoolean()
6162
);
6263
}
6364

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ public void testTelemetryForRRFRetriever() throws IOException {
139139
"some_inference_id",
140140
"some_inference_text",
141141
"some_field",
142-
10
142+
10,
143+
false
143144
)
144145
)
145146
);

0 commit comments

Comments
 (0)