Skip to content

Commit b407a4f

Browse files
committed
Set model registry for each semantic query builder instance
1 parent 659a79f commit b407a4f

File tree

4 files changed

+24
-22
lines changed

4 files changed

+24
-22
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/SemanticCrossClusterSearchIT.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
3838
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
3939
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
40+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
4041

4142
import java.io.IOException;
4243
import java.util.Collection;
@@ -104,8 +105,12 @@ public void testSemanticCrossClusterSearch() throws Exception {
104105
String localIndex = (String) testClusterInfo.get("local.index");
105106
String remoteIndex = (String) testClusterInfo.get("remote.index");
106107

108+
ModelRegistry modelRegistry = cluster(LOCAL_CLUSTER).getCurrentMasterNodeInstance(ModelRegistry.class);
109+
SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(INFERENCE_FIELD, "foo");
110+
queryBuilder.setModelRegistrySupplier(() -> modelRegistry);
111+
107112
SearchRequest searchRequest = new SearchRequest(localIndex, REMOTE_CLUSTER + ":" + remoteIndex);
108-
searchRequest.source(new SearchSourceBuilder().query(new SemanticQueryBuilder(INFERENCE_FIELD, "foo")).size(10));
113+
searchRequest.source(new SearchSourceBuilder().query(queryBuilder).size(10));
109114
searchRequest.setCcsMinimizeRoundtrips(true);
110115

111116
assertResponse(client(LOCAL_CLUSTER).search(searchRequest), response -> {
@@ -155,12 +160,7 @@ public void testMatchCrossClusterSearch() throws Exception {
155160
private Map<String, Object> setupTwoClusters(String[] localIndices, String[] remoteIndices) throws IOException {
156161
final String localInferenceId = "local_inference_id";
157162
final String remoteInferenceId = "remote_inference_id";
158-
159-
// TODO: Resolve bug where remote model registry overwrites local model registry in SemanticQueryBuilder
160163
createInferenceEndpoint(client(LOCAL_CLUSTER), TaskType.TEXT_EMBEDDING, localInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_1);
161-
createInferenceEndpoint(client(LOCAL_CLUSTER), TaskType.TEXT_EMBEDDING, remoteInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_2);
162-
163-
createInferenceEndpoint(client(REMOTE_CLUSTER), TaskType.TEXT_EMBEDDING, localInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_1);
164164
createInferenceEndpoint(client(REMOTE_CLUSTER), TaskType.TEXT_EMBEDDING, remoteInferenceId, TEXT_EMBEDDING_SERVICE_SETTINGS_2);
165165

166166
int numShardsLocal = randomIntBetween(2, 10);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,15 @@ public Collection<MappedActionFilter> getMappedActionFilters() {
568568
}
569569

570570
public List<QuerySpec<?>> getQueries() {
571-
return List.of(new QuerySpec<>(SemanticQueryBuilder.NAME, SemanticQueryBuilder::new, SemanticQueryBuilder::fromXContent));
571+
return List.of(new QuerySpec<>(SemanticQueryBuilder.NAME, i -> {
572+
SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(i);
573+
queryBuilder.setModelRegistrySupplier(getModelRegistry());
574+
return queryBuilder;
575+
}, p -> {
576+
SemanticQueryBuilder queryBuilder = SemanticQueryBuilder.fromXContent(p);
577+
queryBuilder.setModelRegistrySupplier(getModelRegistry());
578+
return queryBuilder;
579+
}));
572580
}
573581

574582
@Override
@@ -602,8 +610,6 @@ public void onNodeStarted() {
602610
if (registry != null) {
603611
registry.onNodeStarted();
604612
}
605-
606-
SemanticQueryBuilder.setModelRegistrySupplier(getModelRegistry());
607613
}
608614

609615
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,14 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
7474
declareStandardFields(PARSER);
7575
}
7676

77-
private static Supplier<ModelRegistry> MODEL_REGISTRY_SUPPLIER = () -> null;
78-
79-
public static void setModelRegistrySupplier(Supplier<ModelRegistry> supplier) {
80-
MODEL_REGISTRY_SUPPLIER = supplier;
81-
}
82-
8377
private final String fieldName;
8478
private final String query;
8579
private final EmbeddingsProvider embeddingsProvider;
8680
private final boolean noInferenceResults;
8781
private final Boolean lenient;
8882

83+
private Supplier<ModelRegistry> modelRegistrySupplier = () -> null;
84+
8985
public SemanticQueryBuilder(String fieldName, String query) {
9086
this(fieldName, query, null);
9187
}
@@ -126,6 +122,10 @@ public SemanticQueryBuilder(StreamInput in) throws IOException {
126122
}
127123
}
128124

125+
public void setModelRegistrySupplier(Supplier<ModelRegistry> supplier) {
126+
modelRegistrySupplier = supplier;
127+
}
128+
129129
@Override
130130
protected void doWriteTo(StreamOutput out) throws IOException {
131131
out.writeString(fieldName);
@@ -150,6 +150,7 @@ private SemanticQueryBuilder(SemanticQueryBuilder other, EmbeddingsProvider embe
150150
this.embeddingsProvider = embeddingsProvider;
151151
this.noInferenceResults = noInferenceResults;
152152
this.lenient = other.lenient;
153+
this.modelRegistrySupplier = other.modelRegistrySupplier;
153154
}
154155

155156
@Override
@@ -208,7 +209,7 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
208209
);
209210
}
210211

211-
ModelRegistry modelRegistry = MODEL_REGISTRY_SUPPLIER.get();
212+
ModelRegistry modelRegistry = modelRegistrySupplier.get();
212213
if (modelRegistry == null) {
213214
throw new IllegalStateException("Model registry has not been set");
214215
}
@@ -272,7 +273,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
272273

273274
boolean modified = false;
274275
if (queryRewriteContext.hasAsyncActions() == false) {
275-
ModelRegistry modelRegistry = MODEL_REGISTRY_SUPPLIER.get();
276+
ModelRegistry modelRegistry = modelRegistrySupplier.get();
276277
if (modelRegistry == null) {
277278
throw new IllegalStateException("Model registry has not been set");
278279
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/LocalStateInferencePlugin.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,4 @@ public Map<String, Highlighter> getHighlighters() {
7373
public Collection<MappedActionFilter> getMappedActionFilters() {
7474
return inferencePlugin.getMappedActionFilters();
7575
}
76-
77-
@Override
78-
public void onNodeStarted() {
79-
inferencePlugin.onNodeStarted();
80-
}
8176
}

0 commit comments

Comments
 (0)