1111import org .elasticsearch .TransportVersion ;
1212import org .elasticsearch .action .ActionListener ;
1313import org .elasticsearch .client .internal .Client ;
14+ import org .elasticsearch .cluster .service .ClusterService ;
1415import org .elasticsearch .common .io .stream .StreamOutput ;
16+ import org .elasticsearch .common .settings .Settings ;
17+ import org .elasticsearch .inference .InferenceService ;
1518import org .elasticsearch .inference .InferenceServiceExtension ;
1619import org .elasticsearch .inference .Model ;
1720import org .elasticsearch .inference .ModelConfigurations ;
4649import java .util .concurrent .CountDownLatch ;
4750import java .util .concurrent .atomic .AtomicReference ;
4851import java .util .function .Consumer ;
52+ import java .util .function .Function ;
4953import java .util .stream .Collectors ;
5054
5155import static org .hamcrest .CoreMatchers .equalTo ;
5660import static org .hamcrest .Matchers .instanceOf ;
5761import static org .hamcrest .Matchers .not ;
5862import static org .hamcrest .Matchers .nullValue ;
63+ import static org .mockito .ArgumentMatchers .any ;
64+ import static org .mockito .Mockito .doAnswer ;
5965import static org .mockito .Mockito .mock ;
6066
6167public class ModelRegistryIT extends ESSingleNodeTestCase {
@@ -121,7 +127,12 @@ public void testGetModel() throws Exception {
121127 assertEquals (model .getConfigurations ().getService (), modelHolder .get ().service ());
122128
123129 var elserService = new ElasticsearchInternalService (
124- new InferenceServiceExtension .InferenceServiceFactoryContext (mock (Client .class ), mock (ThreadPool .class ))
130+ new InferenceServiceExtension .InferenceServiceFactoryContext (
131+ mock (Client .class ),
132+ mock (ThreadPool .class ),
133+ mock (ClusterService .class ),
134+ Settings .EMPTY
135+ )
125136 );
126137 ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel ) elserService .parsePersistedConfigWithSecrets (
127138 modelHolder .get ().inferenceEntityId (),
@@ -282,26 +293,38 @@ public void testGetModelWithSecrets() throws InterruptedException {
282293 }
283294
284295 public void testGetAllModels_WithDefaults () throws Exception {
285- var service = "foo" ;
286- var secret = "abc" ;
296+ var serviceName = "foo" ;
287297 int configuredModelCount = 10 ;
288298 int defaultModelCount = 2 ;
289299 int totalModelCount = 12 ;
290300
291- var defaultConfigs = new HashMap <String , UnparsedModel >();
301+ var service = mock (InferenceService .class );
302+
303+ var defaultConfigs = new ArrayList <Model >();
304+ var defaultIds = new ArrayList <InferenceService .DefaultConfigId >();
292305 for (int i = 0 ; i < defaultModelCount ; i ++) {
293306 var id = "default-" + i ;
294- defaultConfigs .put (id , createUnparsedConfig (id , randomFrom (TaskType .values ()), service , secret ));
307+ var taskType = randomFrom (TaskType .values ());
308+ defaultConfigs .add (createModel (id , taskType , serviceName ));
309+ defaultIds .add (new InferenceService .DefaultConfigId (id , taskType , service ));
295310 }
296- defaultConfigs .values ().forEach (modelRegistry ::addDefaultConfiguration );
311+
312+ doAnswer (invocation -> {
313+ @ SuppressWarnings ("unchecked" )
314+ var listener = (ActionListener <List <Model >>) invocation .getArguments ()[0 ];
315+ listener .onResponse (defaultConfigs );
316+ return Void .TYPE ;
317+ }).when (service ).defaultConfigs (any ());
318+
319+ defaultIds .forEach (modelRegistry ::addDefaultIds );
297320
298321 AtomicReference <Boolean > putModelHolder = new AtomicReference <>();
299322 AtomicReference <Exception > exceptionHolder = new AtomicReference <>();
300323
301324 var createdModels = new HashMap <String , Model >();
302325 for (int i = 0 ; i < configuredModelCount ; i ++) {
303326 var id = randomAlphaOfLength (5 ) + i ;
304- var model = createModel (id , randomFrom (TaskType .values ()), service );
327+ var model = createModel (id , randomFrom (TaskType .values ()), serviceName );
305328 createdModels .put (id , model );
306329 blockingCall (listener -> modelRegistry .storeModel (model , listener ), putModelHolder , exceptionHolder );
307330 assertThat (putModelHolder .get (), is (true ));
@@ -315,16 +338,22 @@ public void testGetAllModels_WithDefaults() throws Exception {
315338 var getAllModels = modelHolder .get ();
316339 assertReturnModelIsModifiable (modelHolder .get ().get (0 ));
317340
341+ // same result but configs should have been persisted this time
342+ blockingCall (listener -> modelRegistry .getAllModels (listener ), modelHolder , exceptionHolder );
343+ assertNull (exceptionHolder .get ());
344+ assertThat (modelHolder .get (), hasSize (totalModelCount ));
345+
318346 // sort in the same order as the returned models
319- var ids = new ArrayList <>(defaultConfigs . keySet ().stream ( ).toList ());
347+ var ids = new ArrayList <>(defaultIds . stream ().map ( InferenceService . DefaultConfigId :: inferenceId ).toList ());
320348 ids .addAll (createdModels .keySet ().stream ().toList ());
321349 ids .sort (String ::compareTo );
350+ var configsById = defaultConfigs .stream ().collect (Collectors .toMap (Model ::getInferenceEntityId , Function .identity ()));
322351 for (int i = 0 ; i < totalModelCount ; i ++) {
323352 var id = ids .get (i );
324353 assertEquals (id , getAllModels .get (i ).inferenceEntityId ());
325354 if (id .startsWith ("default" )) {
326- assertEquals (defaultConfigs .get (id ).taskType (), getAllModels .get (i ).taskType ());
327- assertEquals (defaultConfigs .get (id ).service (), getAllModels .get (i ).service ());
355+ assertEquals (configsById .get (id ).getTaskType (), getAllModels .get (i ).taskType ());
356+ assertEquals (configsById .get (id ).getConfigurations (). getService (), getAllModels .get (i ).service ());
328357 } else {
329358 assertEquals (createdModels .get (id ).getTaskType (), getAllModels .get (i ).taskType ());
330359 assertEquals (createdModels .get (id ).getConfigurations ().getService (), getAllModels .get (i ).service ());
@@ -333,16 +362,27 @@ public void testGetAllModels_WithDefaults() throws Exception {
333362 }
334363
335364 public void testGetAllModels_OnlyDefaults () throws Exception {
336- var service = "foo" ;
337- var secret = "abc" ;
338365 int defaultModelCount = 2 ;
366+ var serviceName = "foo" ;
367+ var service = mock (InferenceService .class );
339368
340- var defaultConfigs = new HashMap <String , UnparsedModel >();
369+ var defaultConfigs = new ArrayList <Model >();
370+ var defaultIds = new ArrayList <InferenceService .DefaultConfigId >();
341371 for (int i = 0 ; i < defaultModelCount ; i ++) {
342372 var id = "default-" + i ;
343- defaultConfigs .put (id , createUnparsedConfig (id , randomFrom (TaskType .values ()), service , secret ));
373+ var taskType = randomFrom (TaskType .values ());
374+ defaultConfigs .add (createModel (id , taskType , serviceName ));
375+ defaultIds .add (new InferenceService .DefaultConfigId (id , taskType , service ));
344376 }
345- defaultConfigs .values ().forEach (modelRegistry ::addDefaultConfiguration );
377+
378+ doAnswer (invocation -> {
379+ @ SuppressWarnings ("unchecked" )
380+ var listener = (ActionListener <List <Model >>) invocation .getArguments ()[0 ];
381+ listener .onResponse (defaultConfigs );
382+ return Void .TYPE ;
383+ }).when (service ).defaultConfigs (any ());
384+
385+ defaultIds .forEach (modelRegistry ::addDefaultIds );
346386
347387 AtomicReference <Exception > exceptionHolder = new AtomicReference <>();
348388 AtomicReference <List <UnparsedModel >> modelHolder = new AtomicReference <>();
@@ -353,31 +393,42 @@ public void testGetAllModels_OnlyDefaults() throws Exception {
353393 assertReturnModelIsModifiable (modelHolder .get ().get (0 ));
354394
355395 // sort in the same order as the returned models
356- var ids = new ArrayList <>(defaultConfigs .keySet ().stream ().toList ());
396+ var configsById = defaultConfigs .stream ().collect (Collectors .toMap (Model ::getInferenceEntityId , Function .identity ()));
397+ var ids = new ArrayList <>(configsById .keySet ().stream ().toList ());
357398 ids .sort (String ::compareTo );
358399 for (int i = 0 ; i < defaultModelCount ; i ++) {
359400 var id = ids .get (i );
360401 assertEquals (id , getAllModels .get (i ).inferenceEntityId ());
361- assertEquals (defaultConfigs .get (id ).taskType (), getAllModels .get (i ).taskType ());
362- assertEquals (defaultConfigs .get (id ).service (), getAllModels .get (i ).service ());
402+ assertEquals (configsById .get (id ).getTaskType (), getAllModels .get (i ).taskType ());
403+ assertEquals (configsById .get (id ).getConfigurations (). getService (), getAllModels .get (i ).service ());
363404 }
364405 }
365406
366407 public void testGet_WithDefaults () throws InterruptedException {
367- var service = "foo" ;
368- var secret = "abc" ;
408+ var serviceName = "foo" ;
409+ var service = mock (InferenceService .class );
410+
411+ var defaultConfigs = new ArrayList <Model >();
412+ var defaultIds = new ArrayList <InferenceService .DefaultConfigId >();
369413
370- var defaultSparse = createUnparsedConfig ("default-sparse" , TaskType .SPARSE_EMBEDDING , service , secret );
371- var defaultText = createUnparsedConfig ("default-text" , TaskType .TEXT_EMBEDDING , service , secret );
414+ defaultConfigs .add (createModel ("default-sparse" , TaskType .SPARSE_EMBEDDING , serviceName ));
415+ defaultConfigs .add (createModel ("default-text" , TaskType .TEXT_EMBEDDING , serviceName ));
416+ defaultIds .add (new InferenceService .DefaultConfigId ("default-sparse" , TaskType .SPARSE_EMBEDDING , service ));
417+ defaultIds .add (new InferenceService .DefaultConfigId ("default-text" , TaskType .TEXT_EMBEDDING , service ));
372418
373- modelRegistry .addDefaultConfiguration (defaultSparse );
374- modelRegistry .addDefaultConfiguration (defaultText );
419+ doAnswer (invocation -> {
420+ @ SuppressWarnings ("unchecked" )
421+ var listener = (ActionListener <List <Model >>) invocation .getArguments ()[0 ];
422+ listener .onResponse (defaultConfigs );
423+ return Void .TYPE ;
424+ }).when (service ).defaultConfigs (any ());
425+ defaultIds .forEach (modelRegistry ::addDefaultIds );
375426
376427 AtomicReference <Boolean > putModelHolder = new AtomicReference <>();
377428 AtomicReference <Exception > exceptionHolder = new AtomicReference <>();
378429
379- var configured1 = createModel (randomAlphaOfLength (5 ) + 1 , randomFrom (TaskType .values ()), service );
380- var configured2 = createModel (randomAlphaOfLength (5 ) + 1 , randomFrom (TaskType .values ()), service );
430+ var configured1 = createModel (randomAlphaOfLength (5 ) + 1 , randomFrom (TaskType .values ()), serviceName );
431+ var configured2 = createModel (randomAlphaOfLength (5 ) + 1 , randomFrom (TaskType .values ()), serviceName );
381432 blockingCall (listener -> modelRegistry .storeModel (configured1 , listener ), putModelHolder , exceptionHolder );
382433 assertThat (putModelHolder .get (), is (true ));
383434 blockingCall (listener -> modelRegistry .storeModel (configured2 , listener ), putModelHolder , exceptionHolder );
@@ -386,6 +437,7 @@ public void testGet_WithDefaults() throws InterruptedException {
386437
387438 AtomicReference <UnparsedModel > modelHolder = new AtomicReference <>();
388439 blockingCall (listener -> modelRegistry .getModel ("default-sparse" , listener ), modelHolder , exceptionHolder );
440+ assertNull (exceptionHolder .get ());
389441 assertEquals ("default-sparse" , modelHolder .get ().inferenceEntityId ());
390442 assertEquals (TaskType .SPARSE_EMBEDDING , modelHolder .get ().taskType ());
391443 assertReturnModelIsModifiable (modelHolder .get ());
@@ -400,23 +452,32 @@ public void testGet_WithDefaults() throws InterruptedException {
400452 }
401453
402454 public void testGetByTaskType_WithDefaults () throws Exception {
403- var service = "foo" ;
404- var secret = "abc" ;
405-
406- var defaultSparse = createUnparsedConfig ("default-sparse" , TaskType .SPARSE_EMBEDDING , service , secret );
407- var defaultText = createUnparsedConfig ("default-text" , TaskType .TEXT_EMBEDDING , service , secret );
408- var defaultChat = createUnparsedConfig ("default-chat" , TaskType .COMPLETION , service , secret );
409-
410- modelRegistry .addDefaultConfiguration (defaultSparse );
411- modelRegistry .addDefaultConfiguration (defaultText );
412- modelRegistry .addDefaultConfiguration (defaultChat );
455+ var serviceName = "foo" ;
456+
457+ var defaultSparse = createModel ("default-sparse" , TaskType .SPARSE_EMBEDDING , serviceName );
458+ var defaultText = createModel ("default-text" , TaskType .TEXT_EMBEDDING , serviceName );
459+ var defaultChat = createModel ("default-chat" , TaskType .COMPLETION , serviceName );
460+
461+ var service = mock (InferenceService .class );
462+ var defaultIds = new ArrayList <InferenceService .DefaultConfigId >();
463+ defaultIds .add (new InferenceService .DefaultConfigId ("default-sparse" , TaskType .SPARSE_EMBEDDING , service ));
464+ defaultIds .add (new InferenceService .DefaultConfigId ("default-text" , TaskType .TEXT_EMBEDDING , service ));
465+ defaultIds .add (new InferenceService .DefaultConfigId ("default-chat" , TaskType .COMPLETION , service ));
466+
467+ doAnswer (invocation -> {
468+ @ SuppressWarnings ("unchecked" )
469+ var listener = (ActionListener <List <Model >>) invocation .getArguments ()[0 ];
470+ listener .onResponse (List .of (defaultSparse , defaultChat , defaultText ));
471+ return Void .TYPE ;
472+ }).when (service ).defaultConfigs (any ());
473+ defaultIds .forEach (modelRegistry ::addDefaultIds );
413474
414475 AtomicReference <Boolean > putModelHolder = new AtomicReference <>();
415476 AtomicReference <Exception > exceptionHolder = new AtomicReference <>();
416477
417- var configuredSparse = createModel ("configured-sparse" , TaskType .SPARSE_EMBEDDING , service );
418- var configuredText = createModel ("configured-text" , TaskType .TEXT_EMBEDDING , service );
419- var configuredRerank = createModel ("configured-rerank" , TaskType .RERANK , service );
478+ var configuredSparse = createModel ("configured-sparse" , TaskType .SPARSE_EMBEDDING , serviceName );
479+ var configuredText = createModel ("configured-text" , TaskType .TEXT_EMBEDDING , serviceName );
480+ var configuredRerank = createModel ("configured-rerank" , TaskType .RERANK , serviceName );
420481 blockingCall (listener -> modelRegistry .storeModel (configuredSparse , listener ), putModelHolder , exceptionHolder );
421482 assertThat (putModelHolder .get (), is (true ));
422483 blockingCall (listener -> modelRegistry .storeModel (configuredText , listener ), putModelHolder , exceptionHolder );
@@ -530,10 +591,6 @@ public static Model createModelWithSecrets(String inferenceEntityId, TaskType ta
530591 );
531592 }
532593
533- public static UnparsedModel createUnparsedConfig (String inferenceEntityId , TaskType taskType , String service , String secret ) {
534- return new UnparsedModel (inferenceEntityId , taskType , service , Map .of ("a" , "b" ), Map .of ("secret" , secret ));
535- }
536-
537594 private static class TestModelOfAnyKind extends ModelConfigurations {
538595
539596 record TestModelServiceSettings () implements ServiceSettings {
0 commit comments