1111import  org .elasticsearch .TransportVersion ;
1212import  org .elasticsearch .action .ActionListener ;
1313import  org .elasticsearch .client .internal .Client ;
14+ import  org .elasticsearch .cluster .service .ClusterService ;
1415import  org .elasticsearch .common .io .stream .StreamOutput ;
16+ import  org .elasticsearch .common .settings .Settings ;
17+ import  org .elasticsearch .inference .InferenceService ;
1518import  org .elasticsearch .inference .InferenceServiceExtension ;
1619import  org .elasticsearch .inference .Model ;
1720import  org .elasticsearch .inference .ModelConfigurations ;
4750import  java .util .concurrent .CountDownLatch ;
4851import  java .util .concurrent .atomic .AtomicReference ;
4952import  java .util .function .Consumer ;
53+ import  java .util .function .Function ;
5054import  java .util .stream .Collectors ;
5155
5256import  static  org .hamcrest .CoreMatchers .equalTo ;
5761import  static  org .hamcrest .Matchers .instanceOf ;
5862import  static  org .hamcrest .Matchers .not ;
5963import  static  org .hamcrest .Matchers .nullValue ;
64+ import  static  org .mockito .ArgumentMatchers .any ;
65+ import  static  org .mockito .Mockito .doAnswer ;
6066import  static  org .mockito .Mockito .mock ;
6167
6268public  class  ModelRegistryIT  extends  ESSingleNodeTestCase  {
@@ -122,7 +128,12 @@ public void testGetModel() throws Exception {
122128        assertEquals (model .getConfigurations ().getService (), modelHolder .get ().service ());
123129
124130        var  elserService  = new  ElasticsearchInternalService (
125-             new  InferenceServiceExtension .InferenceServiceFactoryContext (mock (Client .class ), mock (ThreadPool .class ))
131+             new  InferenceServiceExtension .InferenceServiceFactoryContext (
132+                 mock (Client .class ),
133+                 mock (ThreadPool .class ),
134+                 mock (ClusterService .class ),
135+                 Settings .EMPTY 
136+             )
126137        );
127138        ElasticsearchInternalModel  roundTripModel  = (ElasticsearchInternalModel ) elserService .parsePersistedConfigWithSecrets (
128139            modelHolder .get ().inferenceEntityId (),
@@ -283,26 +294,38 @@ public void testGetModelWithSecrets() throws InterruptedException {
283294    }
284295
285296    public  void  testGetAllModels_WithDefaults () throws  Exception  {
286-         var  service  = "foo" ;
287-         var  secret  = "abc" ;
297+         var  serviceName  = "foo" ;
288298        int  configuredModelCount  = 10 ;
289299        int  defaultModelCount  = 2 ;
290300        int  totalModelCount  = 12 ;
291301
292-         var  defaultConfigs  = new  HashMap <String , UnparsedModel >();
302+         var  service  = mock (InferenceService .class );
303+ 
304+         var  defaultConfigs  = new  ArrayList <Model >();
305+         var  defaultIds  = new  ArrayList <InferenceService .DefaultConfigId >();
293306        for  (int  i  = 0 ; i  < defaultModelCount ; i ++) {
294307            var  id  = "default-"  + i ;
295-             defaultConfigs .put (id , createUnparsedConfig (id , randomFrom (TaskType .values ()), service , secret ));
308+             var  taskType  = randomFrom (TaskType .values ());
309+             defaultConfigs .add (createModel (id , taskType , serviceName ));
310+             defaultIds .add (new  InferenceService .DefaultConfigId (id , taskType , service ));
296311        }
297-         defaultConfigs .values ().forEach (modelRegistry ::addDefaultConfiguration );
312+ 
313+         doAnswer (invocation  -> {
314+             @ SuppressWarnings ("unchecked" )
315+             var  listener  = (ActionListener <List <Model >>) invocation .getArguments ()[0 ];
316+             listener .onResponse (defaultConfigs );
317+             return  Void .TYPE ;
318+         }).when (service ).defaultConfigs (any ());
319+ 
320+         defaultIds .forEach (modelRegistry ::addDefaultIds );
298321
299322        AtomicReference <Boolean > putModelHolder  = new  AtomicReference <>();
300323        AtomicReference <Exception > exceptionHolder  = new  AtomicReference <>();
301324
302325        var  createdModels  = new  HashMap <String , Model >();
303326        for  (int  i  = 0 ; i  < configuredModelCount ; i ++) {
304327            var  id  = randomAlphaOfLength (5 ) + i ;
305-             var  model  = createModel (id , randomFrom (TaskType .values ()), service );
328+             var  model  = createModel (id , randomFrom (TaskType .values ()), serviceName );
306329            createdModels .put (id , model );
307330            blockingCall (listener  -> modelRegistry .storeModel (model , listener ), putModelHolder , exceptionHolder );
308331            assertThat (putModelHolder .get (), is (true ));
@@ -316,16 +339,22 @@ public void testGetAllModels_WithDefaults() throws Exception {
316339        var  getAllModels  = modelHolder .get ();
317340        assertReturnModelIsModifiable (modelHolder .get ().get (0 ));
318341
342+         // same result but configs should have been persisted this time 
343+         blockingCall (listener  -> modelRegistry .getAllModels (listener ), modelHolder , exceptionHolder );
344+         assertNull (exceptionHolder .get ());
345+         assertThat (modelHolder .get (), hasSize (totalModelCount ));
346+ 
319347        // sort in the same order as the returned models 
320-         var  ids  = new  ArrayList <>(defaultConfigs . keySet ().stream ( ).toList ());
348+         var  ids  = new  ArrayList <>(defaultIds . stream ().map ( InferenceService . DefaultConfigId :: inferenceId ).toList ());
321349        ids .addAll (createdModels .keySet ().stream ().toList ());
322350        ids .sort (String ::compareTo );
351+         var  configsById  = defaultConfigs .stream ().collect (Collectors .toMap (Model ::getInferenceEntityId , Function .identity ()));
323352        for  (int  i  = 0 ; i  < totalModelCount ; i ++) {
324353            var  id  = ids .get (i );
325354            assertEquals (id , getAllModels .get (i ).inferenceEntityId ());
326355            if  (id .startsWith ("default" )) {
327-                 assertEquals (defaultConfigs .get (id ).taskType (), getAllModels .get (i ).taskType ());
328-                 assertEquals (defaultConfigs .get (id ).service (), getAllModels .get (i ).service ());
356+                 assertEquals (configsById .get (id ).getTaskType (), getAllModels .get (i ).taskType ());
357+                 assertEquals (configsById .get (id ).getConfigurations (). getService (), getAllModels .get (i ).service ());
329358            } else  {
330359                assertEquals (createdModels .get (id ).getTaskType (), getAllModels .get (i ).taskType ());
331360                assertEquals (createdModels .get (id ).getConfigurations ().getService (), getAllModels .get (i ).service ());
@@ -334,16 +363,27 @@ public void testGetAllModels_WithDefaults() throws Exception {
334363    }
335364
336365    public  void  testGetAllModels_OnlyDefaults () throws  Exception  {
337-         var  service  = "foo" ;
338-         var  secret  = "abc" ;
339366        int  defaultModelCount  = 2 ;
367+         var  serviceName  = "foo" ;
368+         var  service  = mock (InferenceService .class );
340369
341-         var  defaultConfigs  = new  HashMap <String , UnparsedModel >();
370+         var  defaultConfigs  = new  ArrayList <Model >();
371+         var  defaultIds  = new  ArrayList <InferenceService .DefaultConfigId >();
342372        for  (int  i  = 0 ; i  < defaultModelCount ; i ++) {
343373            var  id  = "default-"  + i ;
344-             defaultConfigs .put (id , createUnparsedConfig (id , randomFrom (TaskType .values ()), service , secret ));
374+             var  taskType  = randomFrom (TaskType .values ());
375+             defaultConfigs .add (createModel (id , taskType , serviceName ));
376+             defaultIds .add (new  InferenceService .DefaultConfigId (id , taskType , service ));
345377        }
346-         defaultConfigs .values ().forEach (modelRegistry ::addDefaultConfiguration );
378+ 
379+         doAnswer (invocation  -> {
380+             @ SuppressWarnings ("unchecked" )
381+             var  listener  = (ActionListener <List <Model >>) invocation .getArguments ()[0 ];
382+             listener .onResponse (defaultConfigs );
383+             return  Void .TYPE ;
384+         }).when (service ).defaultConfigs (any ());
385+ 
386+         defaultIds .forEach (modelRegistry ::addDefaultIds );
347387
348388        AtomicReference <Exception > exceptionHolder  = new  AtomicReference <>();
349389        AtomicReference <List <UnparsedModel >> modelHolder  = new  AtomicReference <>();
@@ -354,31 +394,42 @@ public void testGetAllModels_OnlyDefaults() throws Exception {
354394        assertReturnModelIsModifiable (modelHolder .get ().get (0 ));
355395
356396        // sort in the same order as the returned models 
357-         var  ids  = new  ArrayList <>(defaultConfigs .keySet ().stream ().toList ());
397+         var  configsById  = defaultConfigs .stream ().collect (Collectors .toMap (Model ::getInferenceEntityId , Function .identity ()));
398+         var  ids  = new  ArrayList <>(configsById .keySet ().stream ().toList ());
358399        ids .sort (String ::compareTo );
359400        for  (int  i  = 0 ; i  < defaultModelCount ; i ++) {
360401            var  id  = ids .get (i );
361402            assertEquals (id , getAllModels .get (i ).inferenceEntityId ());
362-             assertEquals (defaultConfigs .get (id ).taskType (), getAllModels .get (i ).taskType ());
363-             assertEquals (defaultConfigs .get (id ).service (), getAllModels .get (i ).service ());
403+             assertEquals (configsById .get (id ).getTaskType (), getAllModels .get (i ).taskType ());
404+             assertEquals (configsById .get (id ).getConfigurations (). getService (), getAllModels .get (i ).service ());
364405        }
365406    }
366407
367408    public  void  testGet_WithDefaults () throws  InterruptedException  {
368-         var  service  = "foo" ;
369-         var  secret  = "abc" ;
409+         var  serviceName  = "foo" ;
410+         var  service  = mock (InferenceService .class );
411+ 
412+         var  defaultConfigs  = new  ArrayList <Model >();
413+         var  defaultIds  = new  ArrayList <InferenceService .DefaultConfigId >();
370414
371-         var  defaultSparse  = createUnparsedConfig ("default-sparse" , TaskType .SPARSE_EMBEDDING , service , secret );
372-         var  defaultText  = createUnparsedConfig ("default-text" , TaskType .TEXT_EMBEDDING , service , secret );
415+         defaultConfigs .add (createModel ("default-sparse" , TaskType .SPARSE_EMBEDDING , serviceName ));
416+         defaultConfigs .add (createModel ("default-text" , TaskType .TEXT_EMBEDDING , serviceName ));
417+         defaultIds .add (new  InferenceService .DefaultConfigId ("default-sparse" , TaskType .SPARSE_EMBEDDING , service ));
418+         defaultIds .add (new  InferenceService .DefaultConfigId ("default-text" , TaskType .TEXT_EMBEDDING , service ));
373419
374-         modelRegistry .addDefaultConfiguration (defaultSparse );
375-         modelRegistry .addDefaultConfiguration (defaultText );
420+         doAnswer (invocation  -> {
421+             @ SuppressWarnings ("unchecked" )
422+             var  listener  = (ActionListener <List <Model >>) invocation .getArguments ()[0 ];
423+             listener .onResponse (defaultConfigs );
424+             return  Void .TYPE ;
425+         }).when (service ).defaultConfigs (any ());
426+         defaultIds .forEach (modelRegistry ::addDefaultIds );
376427
377428        AtomicReference <Boolean > putModelHolder  = new  AtomicReference <>();
378429        AtomicReference <Exception > exceptionHolder  = new  AtomicReference <>();
379430
380-         var  configured1  = createModel (randomAlphaOfLength (5 ) + 1 , randomFrom (TaskType .values ()), service );
381-         var  configured2  = createModel (randomAlphaOfLength (5 ) + 1 , randomFrom (TaskType .values ()), service );
431+         var  configured1  = createModel (randomAlphaOfLength (5 ) + 1 , randomFrom (TaskType .values ()), serviceName );
432+         var  configured2  = createModel (randomAlphaOfLength (5 ) + 1 , randomFrom (TaskType .values ()), serviceName );
382433        blockingCall (listener  -> modelRegistry .storeModel (configured1 , listener ), putModelHolder , exceptionHolder );
383434        assertThat (putModelHolder .get (), is (true ));
384435        blockingCall (listener  -> modelRegistry .storeModel (configured2 , listener ), putModelHolder , exceptionHolder );
@@ -387,6 +438,7 @@ public void testGet_WithDefaults() throws InterruptedException {
387438
388439        AtomicReference <UnparsedModel > modelHolder  = new  AtomicReference <>();
389440        blockingCall (listener  -> modelRegistry .getModel ("default-sparse" , listener ), modelHolder , exceptionHolder );
441+         assertNull (exceptionHolder .get ());
390442        assertEquals ("default-sparse" , modelHolder .get ().inferenceEntityId ());
391443        assertEquals (TaskType .SPARSE_EMBEDDING , modelHolder .get ().taskType ());
392444        assertReturnModelIsModifiable (modelHolder .get ());
@@ -401,23 +453,32 @@ public void testGet_WithDefaults() throws InterruptedException {
401453    }
402454
403455    public  void  testGetByTaskType_WithDefaults () throws  Exception  {
404-         var  service  = "foo" ;
405-         var  secret  = "abc" ;
406- 
407-         var  defaultSparse  = createUnparsedConfig ("default-sparse" , TaskType .SPARSE_EMBEDDING , service , secret );
408-         var  defaultText  = createUnparsedConfig ("default-text" , TaskType .TEXT_EMBEDDING , service , secret );
409-         var  defaultChat  = createUnparsedConfig ("default-chat" , TaskType .COMPLETION , service , secret );
410- 
411-         modelRegistry .addDefaultConfiguration (defaultSparse );
412-         modelRegistry .addDefaultConfiguration (defaultText );
413-         modelRegistry .addDefaultConfiguration (defaultChat );
456+         var  serviceName  = "foo" ;
457+ 
458+         var  defaultSparse  = createModel ("default-sparse" , TaskType .SPARSE_EMBEDDING , serviceName );
459+         var  defaultText  = createModel ("default-text" , TaskType .TEXT_EMBEDDING , serviceName );
460+         var  defaultChat  = createModel ("default-chat" , TaskType .COMPLETION , serviceName );
461+ 
462+         var  service  = mock (InferenceService .class );
463+         var  defaultIds  = new  ArrayList <InferenceService .DefaultConfigId >();
464+         defaultIds .add (new  InferenceService .DefaultConfigId ("default-sparse" , TaskType .SPARSE_EMBEDDING , service ));
465+         defaultIds .add (new  InferenceService .DefaultConfigId ("default-text" , TaskType .TEXT_EMBEDDING , service ));
466+         defaultIds .add (new  InferenceService .DefaultConfigId ("default-chat" , TaskType .COMPLETION , service ));
467+ 
468+         doAnswer (invocation  -> {
469+             @ SuppressWarnings ("unchecked" )
470+             var  listener  = (ActionListener <List <Model >>) invocation .getArguments ()[0 ];
471+             listener .onResponse (List .of (defaultSparse , defaultChat , defaultText ));
472+             return  Void .TYPE ;
473+         }).when (service ).defaultConfigs (any ());
474+         defaultIds .forEach (modelRegistry ::addDefaultIds );
414475
415476        AtomicReference <Boolean > putModelHolder  = new  AtomicReference <>();
416477        AtomicReference <Exception > exceptionHolder  = new  AtomicReference <>();
417478
418-         var  configuredSparse  = createModel ("configured-sparse" , TaskType .SPARSE_EMBEDDING , service );
419-         var  configuredText  = createModel ("configured-text" , TaskType .TEXT_EMBEDDING , service );
420-         var  configuredRerank  = createModel ("configured-rerank" , TaskType .RERANK , service );
479+         var  configuredSparse  = createModel ("configured-sparse" , TaskType .SPARSE_EMBEDDING , serviceName );
480+         var  configuredText  = createModel ("configured-text" , TaskType .TEXT_EMBEDDING , serviceName );
481+         var  configuredRerank  = createModel ("configured-rerank" , TaskType .RERANK , serviceName );
421482        blockingCall (listener  -> modelRegistry .storeModel (configuredSparse , listener ), putModelHolder , exceptionHolder );
422483        assertThat (putModelHolder .get (), is (true ));
423484        blockingCall (listener  -> modelRegistry .storeModel (configuredText , listener ), putModelHolder , exceptionHolder );
@@ -531,10 +592,6 @@ public static Model createModelWithSecrets(String inferenceEntityId, TaskType ta
531592        );
532593    }
533594
534-     public  static  UnparsedModel  createUnparsedConfig (String  inferenceEntityId , TaskType  taskType , String  service , String  secret ) {
535-         return  new  UnparsedModel (inferenceEntityId , taskType , service , Map .of ("a" , "b" ), Map .of ("secret" , secret ));
536-     }
537- 
538595    private  static  class  TestModelOfAnyKind  extends  ModelConfigurations  {
539596
540597        record  TestModelServiceSettings () implements  ServiceSettings  {
0 commit comments