Skip to content

Commit 8db87ab

Browse files
committed
Add model registry to semantic query builder
1 parent 1e5cb6f commit 8db87ab

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,10 +592,11 @@ public Map<String, Highlighter> getHighlighters() {
592592
@Override
593593
public void onNodeStarted() {
594594
var registry = inferenceServiceRegistry.get();
595-
596595
if (registry != null) {
597596
registry.onNodeStarted();
598597
}
598+
599+
SemanticQueryBuilder.setModelRegistrySupplier(getModelRegistry());
599600
}
600601

601602
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@
3636
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
3737
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
3838
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
39+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
3940

4041
import java.io.IOException;
4142
import java.util.Collection;
4243
import java.util.List;
4344
import java.util.Map;
4445
import java.util.Objects;
46+
import java.util.function.Supplier;
4547

4648
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
4749
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -68,6 +70,12 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
6870
declareStandardFields(PARSER);
6971
}
7072

73+
private static Supplier<ModelRegistry> MODEL_REGISTRY_SUPPLIER = () -> null;
74+
75+
public static void setModelRegistrySupplier(Supplier<ModelRegistry> supplier) {
76+
MODEL_REGISTRY_SUPPLIER = supplier;
77+
}
78+
7179
private final String fieldName;
7280
private final String query;
7381
private final SetOnce<InferenceServiceResults> inferenceResultsSupplier;

0 commit comments

Comments
 (0)