Skip to content

Commit 89c8998

Browse files
committed
Update to use singleton pattern
Signed-off-by: Atri Sharma <atri.jiit@gmail.com>
1 parent 4155cfe commit 89c8998

File tree

2 files changed

+135
-11
lines changed

2 files changed

+135
-11
lines changed

server/src/main/java/org/opensearch/search/query/QueryRewriterRegistry.java

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,66 @@
2121
import java.util.ArrayList;
2222
import java.util.Comparator;
2323
import java.util.List;
24+
import java.util.concurrent.CopyOnWriteArrayList;
2425

2526
/**
26-
* Registry for query rewriters.
27+
* Registry for query rewriters
2728
*
2829
* @opensearch.internal
2930
*/
3031
public final class QueryRewriterRegistry {
3132

3233
private static final Logger logger = LogManager.getLogger(QueryRewriterRegistry.class);
3334

34-
private QueryRewriterRegistry() {}
35+
private static final QueryRewriterRegistry INSTANCE = new QueryRewriterRegistry();
36+
37+
/**
38+
* Default rewriters.
39+
* CopyOnWriteArrayList is used for thread-safety during registration.
40+
*/
41+
private final CopyOnWriteArrayList<QueryRewriter> rewriters;
42+
43+
private QueryRewriterRegistry() {
44+
this.rewriters = new CopyOnWriteArrayList<>();
45+
46+
// Register default rewriters
47+
// Note: TermsMergingRewriter is special - it needs threshold at runtime
48+
registerRewriter(new BooleanFlatteningRewriter());
49+
registerRewriter(new MustToFilterRewriter());
50+
registerRewriter(new MustNotToShouldRewriter());
51+
registerRewriter(new MatchAllRemovalRewriter());
52+
}
53+
54+
/**
55+
* Get the singleton instance of the registry.
56+
*/
57+
public static QueryRewriterRegistry getInstance() {
58+
return INSTANCE;
59+
}
60+
61+
/**
62+
* Register a custom query rewriter.
63+
*
64+
* @param rewriter The rewriter to register
65+
*/
66+
public void registerRewriter(QueryRewriter rewriter) {
67+
if (rewriter != null) {
68+
rewriters.add(rewriter);
69+
logger.info("Registered query rewriter: {}", rewriter.name());
70+
}
71+
}
72+
73+
/**
74+
* Get a list of all rewriters with the given terms threshold.
75+
*/
76+
private List<QueryRewriter> getRewritersWithThreshold(int termsThreshold) {
77+
List<QueryRewriter> allRewriters = new ArrayList<>(rewriters);
78+
// Add TermsMergingRewriter with the current threshold
79+
// This is added dynamically because it needs the threshold parameter
80+
allRewriters.add(new TermsMergingRewriter(termsThreshold));
81+
allRewriters.sort(Comparator.comparingInt(QueryRewriter::priority));
82+
return allRewriters;
83+
}
3584

3685
public static QueryBuilder rewrite(QueryBuilder query, QueryShardContext context, boolean enabled) {
3786
return rewrite(query, context, enabled, 16);
@@ -42,17 +91,11 @@ public static QueryBuilder rewrite(QueryBuilder query, QueryShardContext context
4291
return query;
4392
}
4493

45-
// Create rewriters with the current threshold
46-
List<QueryRewriter> currentRewriters = new ArrayList<>();
47-
currentRewriters.add(new BooleanFlatteningRewriter());
48-
currentRewriters.add(new MustToFilterRewriter());
49-
currentRewriters.add(new MustNotToShouldRewriter());
50-
currentRewriters.add(new TermsMergingRewriter(termsThreshold));
51-
currentRewriters.add(new MatchAllRemovalRewriter());
52-
currentRewriters.sort(Comparator.comparingInt(QueryRewriter::priority));
94+
QueryRewriterRegistry registry = getInstance();
95+
List<QueryRewriter> sortedRewriters = registry.getRewritersWithThreshold(termsThreshold);
5396

5497
QueryBuilder current = query;
55-
for (QueryRewriter rewriter : currentRewriters) {
98+
for (QueryRewriter rewriter : sortedRewriters) {
5699
try {
57100
QueryBuilder rewritten = rewriter.rewrite(current, context);
58101
if (rewritten != current) {

server/src/test/java/org/opensearch/search/query/QueryRewriterRegistryTests.java

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,4 +216,85 @@ public void testVeryComplexMixedQuery() {
216216
assertTrue(result.must().size() >= 1);
217217
assertTrue(result.filter().size() >= 1);
218218
}
219+
220+
public void testCustomRewriterRegistration() {
221+
// Create a custom rewriter for testing
222+
QueryRewriter customRewriter = new QueryRewriter() {
223+
@Override
224+
public QueryBuilder rewrite(QueryBuilder query, QueryShardContext context) {
225+
if (query instanceof TermQueryBuilder) {
226+
TermQueryBuilder termQuery = (TermQueryBuilder) query;
227+
if ("test_field".equals(termQuery.fieldName()) && "test_value".equals(termQuery.value())) {
228+
// Replace with a different query
229+
return QueryBuilders.termQuery("custom_field", "custom_value");
230+
}
231+
} else if (query instanceof BoolQueryBuilder) {
232+
// Recursively apply to nested queries
233+
BoolQueryBuilder boolQuery = (BoolQueryBuilder) query;
234+
BoolQueryBuilder rewritten = new BoolQueryBuilder();
235+
236+
// Copy settings
237+
rewritten.boost(boolQuery.boost());
238+
rewritten.queryName(boolQuery.queryName());
239+
rewritten.minimumShouldMatch(boolQuery.minimumShouldMatch());
240+
rewritten.adjustPureNegative(boolQuery.adjustPureNegative());
241+
242+
// Recursively rewrite clauses
243+
boolean changed = false;
244+
for (QueryBuilder must : boolQuery.must()) {
245+
QueryBuilder rewrittenClause = rewrite(must, context);
246+
rewritten.must(rewrittenClause);
247+
if (rewrittenClause != must) changed = true;
248+
}
249+
for (QueryBuilder filter : boolQuery.filter()) {
250+
QueryBuilder rewrittenClause = rewrite(filter, context);
251+
rewritten.filter(rewrittenClause);
252+
if (rewrittenClause != filter) changed = true;
253+
}
254+
for (QueryBuilder should : boolQuery.should()) {
255+
QueryBuilder rewrittenClause = rewrite(should, context);
256+
rewritten.should(rewrittenClause);
257+
if (rewrittenClause != should) changed = true;
258+
}
259+
for (QueryBuilder mustNot : boolQuery.mustNot()) {
260+
QueryBuilder rewrittenClause = rewrite(mustNot, context);
261+
rewritten.mustNot(rewrittenClause);
262+
if (rewrittenClause != mustNot) changed = true;
263+
}
264+
265+
return changed ? rewritten : query;
266+
}
267+
return query;
268+
}
269+
270+
@Override
271+
public int priority() {
272+
return 1000; // High priority to ensure it runs last
273+
}
274+
275+
@Override
276+
public String name() {
277+
return "test_custom_rewriter";
278+
}
279+
};
280+
281+
// Register the custom rewriter
282+
QueryRewriterRegistry.getInstance().registerRewriter(customRewriter);
283+
284+
// Test that it's applied
285+
QueryBuilder query = QueryBuilders.boolQuery()
286+
.must(QueryBuilders.termQuery("test_field", "test_value"))
287+
.filter(QueryBuilders.termQuery("other_field", "other_value"));
288+
289+
QueryBuilder rewritten = QueryRewriterRegistry.rewrite(query, context, true);
290+
assertThat(rewritten, instanceOf(BoolQueryBuilder.class));
291+
BoolQueryBuilder rewrittenBool = (BoolQueryBuilder) rewritten;
292+
293+
// The custom rewriter should have replaced the term query
294+
assertThat(rewrittenBool.must().size(), equalTo(1));
295+
assertThat(rewrittenBool.must().get(0), instanceOf(TermQueryBuilder.class));
296+
TermQueryBuilder mustTerm = (TermQueryBuilder) rewrittenBool.must().get(0);
297+
assertThat(mustTerm.fieldName(), equalTo("custom_field"));
298+
assertThat(mustTerm.value(), equalTo("custom_value"));
299+
}
219300
}

0 commit comments

Comments
 (0)