Skip to content

Commit 43dda15

Browse files
committed
Check that model registry is set
1 parent 42d5eab commit 43dda15

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,13 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx
209209
);
210210
}
211211

212-
// TODO: Check that model registry supplier has been set
212+
ModelRegistry modelRegistry = MODEL_REGISTRY_SUPPLIER.get();
213+
if (modelRegistry == null) {
214+
throw new IllegalStateException("Model registry has not been set");
215+
}
216+
213217
String inferenceId = semanticTextFieldType.getSearchInferenceId();
214-
MinimalServiceSettings serviceSettings = MODEL_REGISTRY_SUPPLIER.get().getMinimalServiceSettings(inferenceId);
218+
MinimalServiceSettings serviceSettings = modelRegistry.getMinimalServiceSettings(inferenceId);
215219
InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings);
216220
InferenceResults inferenceResults = embeddingsProvider.getEmbeddings(inferenceEndpointKey);
217221

@@ -273,8 +277,12 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
273277
false
274278
);
275279

276-
// TODO: Check that model registry supplier has been set
277-
MinimalServiceSettings serviceSettings = MODEL_REGISTRY_SUPPLIER.get().getMinimalServiceSettings(inferenceId);
280+
ModelRegistry modelRegistry = MODEL_REGISTRY_SUPPLIER.get();
281+
if (modelRegistry == null) {
282+
throw new IllegalStateException("Model registry has not been set");
283+
}
284+
285+
MinimalServiceSettings serviceSettings = modelRegistry.getMinimalServiceSettings(inferenceId);
278286
InferenceEndpointKey inferenceEndpointKey = new InferenceEndpointKey(inferenceId, serviceSettings);
279287
queryRewriteContext.registerAsyncAction(
280288
(client, listener) -> executeAsyncWithOrigin(

0 commit comments

Comments
 (0)