Skip to content

Commit b6d5109

Browse files
committed
WIP
1 parent cdef674 commit b6d5109

File tree

2 files changed

+276
-26
lines changed

2 files changed

+276
-26
lines changed

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,34 @@ private static ScoreNormalizer[] getDefaultNormalizers(List<RetrieverSource> inn
123123
return normalizers;
124124
}
125125

126+
private void normalizeNormalizerArray(ScoreNormalizer topLevelNormalizer, ScoreNormalizer[] normalizers) {
127+
for (int i = 0; i < normalizers.length; i++) {
128+
ScoreNormalizer current = normalizers[i];
129+
130+
if (topLevelNormalizer != null) {
131+
// Validate explicit per-retriever normalizers match top-level
132+
if (current != null && !current.equals(DEFAULT_NORMALIZER) && !current.equals(topLevelNormalizer)) {
133+
throw new IllegalArgumentException(
134+
String.format(
135+
"[%s] All per-retriever normalizers must match the top-level normalizer: "
136+
+ "expected [%s], found [%s] in retriever [%d]",
137+
NAME, topLevelNormalizer.getName(), current.getName(), i
138+
)
139+
);
140+
}
141+
// Propagate top-level normalizer to unspecified positions
142+
if (current == null || current.equals(DEFAULT_NORMALIZER)) {
143+
normalizers[i] = topLevelNormalizer;
144+
}
145+
} else {
146+
// No top-level normalizer: ensure null values become DEFAULT_NORMALIZER
147+
if (current == null) {
148+
normalizers[i] = DEFAULT_NORMALIZER;
149+
}
150+
}
151+
}
152+
}
153+
126154
public static LinearRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
127155
if (context.clusterSupportsFeature(LINEAR_RETRIEVER_SUPPORTED) == false) {
128156
throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]");
@@ -185,32 +213,7 @@ public LinearRetrieverBuilder(
185213
this.query = query;
186214
this.normalizer = normalizer;
187215

188-
if (normalizer != null) {
189-
// First pass: validate that any specified per-retriever normalizers match the top-level one
190-
for (int i = 0; i < normalizers.length; i++) {
191-
ScoreNormalizer subNormalizer = normalizers[i];
192-
if (subNormalizer != null && !subNormalizer.equals(DEFAULT_NORMALIZER) && !subNormalizer.equals(normalizer)) {
193-
throw new IllegalArgumentException(
194-
"["
195-
+ NAME
196-
+ "] All per-retriever normalizers must match the top-level normalizer: "
197-
+ "expected ["
198-
+ normalizer.getName()
199-
+ "], found ["
200-
+ subNormalizer.getName()
201-
+ "] in retriever ["
202-
+ i
203-
+ "]"
204-
);
205-
}
206-
}
207-
// Second pass: propagate top-level normalizer to any unspecified positions
208-
for (int i = 0; i < normalizers.length; i++) {
209-
if (normalizers[i] == null || normalizers[i].equals(DEFAULT_NORMALIZER)) {
210-
normalizers[i] = normalizer;
211-
}
212-
}
213-
}
216+
normalizeNormalizerArray(normalizer, normalizers);
214217

215218
}
216219

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
setup:
  - requires:
      cluster_features: [ "linear_retriever_supported", "linear_retriever.l2_norm" ]
      reason: "Support for linear retriever and L2 normalization"
      test_runner_features: close_to

  - do:
      indices.create:
        index: test
        body:
          mappings:
            properties:
              vector:
                type: dense_vector
                dims: 1
                index: true
                similarity: l2_norm
                index_options:
                  type: flat
              keyword:
                type: keyword
              other_keyword:
                type: keyword
              timestamp:
                type: date

  - do:
      bulk:
        refresh: true
        index: test
        body:
          - '{"index": {"_id": 1 }}'
          - '{"vector": [1], "keyword": "one", "other_keyword": "other", "timestamp": "2021-01-01T00:00:00"}'
          - '{"index": {"_id": 2 }}'
          - '{"vector": [2], "keyword": "two", "timestamp": "2022-01-01T00:00:00"}'
          - '{"index": {"_id": 3 }}'
          - '{"vector": [3], "keyword": "three", "timestamp": "2023-01-01T00:00:00"}'
          - '{"index": {"_id": 4 }}'
          - '{"vector": [4], "keyword": "four", "other_keyword": "other", "timestamp": "2024-01-01T00:00:00"}'

---
"Linear retriever with top-level L2 normalization":
  - do:
      search:
        index: test
        body:
          retriever:
            linear:
              normalizer: l2_norm
              retrievers: [
                {
                  retriever: {
                    standard: {
                      query: {
                        constant_score: {
                          filter: {
                            term: {
                              keyword: {
                                value: "one"
                              }
                            }
                          },
                          boost: 5.0
                        }
                      }
                    }
                  },
                  weight: 1.0
                },
                {
                  retriever: {
                    standard: {
                      query: {
                        constant_score: {
                          filter: {
                            term: {
                              keyword: {
                                value: "four"
                              }
                            }
                          },
                          boost: 12.0
                        }
                      }
                    }
                  },
                  weight: 1.0
                }
              ]

  - match: { hits.total.value: 2 }
  - match: { hits.hits.0._id: "4" } # Doc 4 should rank higher with normalized scores
  - match: { hits.hits.1._id: "1" }
  # With L2 normalization: [5.0, 12.0] becomes [5.0/13.0, 12.0/13.0]
  - close_to: { hits.hits.0._score: { value: 0.923, error: 0.01 } } # 12.0/13.0
  - close_to: { hits.hits.1._score: { value: 0.385, error: 0.01 } } # 5.0/13.0

---
"Linear retriever with per-retriever L2 normalization":
  - do:
      search:
        index: test
        body:
          retriever:
            linear:
              retrievers: [
                {
                  retriever: {
                    standard: {
                      query: {
                        constant_score: {
                          filter: {
                            term: {
                              keyword: {
                                value: "one"
                              }
                            }
                          },
                          boost: 5.0
                        }
                      }
                    }
                  },
                  weight: 1.0,
                  normalizer: l2_norm
                },
                {
                  retriever: {
                    standard: {
                      query: {
                        constant_score: {
                          filter: {
                            term: {
                              keyword: {
                                value: "four"
                              }
                            }
                          },
                          boost: 12.0
                        }
                      }
                    }
                  },
                  weight: 1.0,
                  normalizer: l2_norm
                }
              ]

  - match: { hits.total.value: 2 }
  # With per-retriever L2 normalization, both scores would be normalized to 1.0
  # So final score = 1.0 * weight1 + 1.0 * weight2 = 2.0 for each doc
  # Then sorting is done by _doc (or some other tiebreaker)
  - close_to: { hits.hits.0._score: { value: 1.0, error: 0.01 } }
  - close_to: { hits.hits.1._score: { value: 1.0, error: 0.01 } }

---
"Linear retriever with mixed normalization (top-level and per-retriever with same normalizer)":
  - do:
      search:
        index: test
        body:
          retriever:
            linear:
              normalizer: l2_norm
              retrievers: [
                {
                  retriever: {
                    standard: {
                      query: {
                        constant_score: {
                          filter: {
                            term: {
                              keyword: {
                                value: "one"
                              }
                            }
                          },
                          boost: 5.0
                        }
                      }
                    }
                  },
                  weight: 1.0
                },
                {
                  retriever: {
                    standard: {
                      query: {
                        constant_score: {
                          filter: {
                            term: {
                              keyword: {
                                value: "four"
                              }
                            }
                          },
                          boost: 12.0
                        }
                      }
                    }
                  },
                  weight: 1.0,
                  normalizer: l2_norm
                }
              ]

  - match: { hits.total.value: 2 }
  - match: { hits.hits.0._id: "4" }
  - match: { hits.hits.1._id: "1" }
  # With L2 normalization: [5.0, 12.0] becomes [5.0/13.0, 12.0/13.0]
  - close_to: { hits.hits.0._score: { value: 0.923, error: 0.01 } }
  - close_to: { hits.hits.1._score: { value: 0.385, error: 0.01 } }

---
"Linear retriever with mismatched normalizers (should fail)":
  - do:
      catch: bad_request
      search:
        index: test
        body:
          retriever:
            linear:
              normalizer: l2_norm
              retrievers: [
                {
                  retriever: {
                    standard: {
                      query: {
                        match_all: {}
                      }
                    }
                  }
                },
                {
                  retriever: {
                    standard: {
                      query: {
                        match_all: {}
                      }
                    }
                  },
                  normalizer: minmax
                }
              ]

  - match: { error.root_cause.0.type: "illegal_argument_exception" }
  - match: { error.root_cause.0.reason: /.*All per-retriever normalizers must match the top-level normalizer.*/ }

0 commit comments

Comments
 (0)