4949import org .elasticsearch .xcontent .XContentBuilder ;
5050import org .elasticsearch .xcontent .XContentFactory ;
5151import org .elasticsearch .xpack .core .ClientHelper ;
52+ import org .elasticsearch .xpack .inference .DefaultElserFeatureFlag ;
5253import org .elasticsearch .xpack .inference .InferenceIndex ;
5354import org .elasticsearch .xpack .inference .InferenceSecretsIndex ;
5455import org .elasticsearch .xpack .inference .services .ServiceUtils ;
@@ -117,19 +118,23 @@ public ModelRegistry(Client client) {
117118 * @param defaultConfigIds The defaults
118119 */
119120 public void addDefaultIds (InferenceService .DefaultConfigId defaultConfigIds ) {
120- var matched = idMatchedDefault (defaultConfigIds .inferenceId (), this .defaultConfigIds );
121- if (matched .isPresent ()) {
122- throw new IllegalStateException (
123- "Cannot add default endpoint to the inference endpoint registry with duplicate inference id ["
124- + defaultConfigIds .inferenceId ()
125- + "] declared by service ["
126- + defaultConfigIds .service ().name ()
127- + "]. The inference Id is already use by ["
128- + matched .get ().service ().name ()
129- + "] service."
130- );
121+ if (DefaultElserFeatureFlag .isEnabled ()) {
122+ var matched = idMatchedDefault (defaultConfigIds .inferenceId (), this .defaultConfigIds );
123+ if (matched .isPresent ()) {
124+ throw new IllegalStateException (
125+ "Cannot add default endpoint to the inference endpoint registry with duplicate inference id ["
126+ + defaultConfigIds .inferenceId ()
127+ + "] declared by service ["
128+ + defaultConfigIds .service ().name ()
129+ + "]. The inference Id is already use by ["
130+ + matched .get ().service ().name ()
131+ + "] service."
132+ );
133+ }
134+ this .defaultConfigIds .add (defaultConfigIds );
135+ } else {
136+ logger .error ("Attempted to addDefaultIds [{}] with the feature flag disabled" , defaultConfigIds .inferenceId ());
131137 }
132- this .defaultConfigIds .add (defaultConfigIds );
133138 }
134139
135140 /**
@@ -142,7 +147,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener<Unparse
142147 // There should be a hit for the configurations
143148 if (searchResponse .getHits ().getHits ().length == 0 ) {
144149 var maybeDefault = idMatchedDefault (inferenceEntityId , defaultConfigIds );
145- if (maybeDefault .isPresent ()) {
150+ if (DefaultElserFeatureFlag . isEnabled () && maybeDefault .isPresent ()) {
146151 getDefaultConfig (true , maybeDefault .get (), listener );
147152 } else {
148153 delegate .onFailure (inferenceNotFoundException (inferenceEntityId ));
@@ -173,7 +178,7 @@ public void getModel(String inferenceEntityId, ActionListener<UnparsedModel> lis
173178 // There should be a hit for the configurations
174179 if (searchResponse .getHits ().getHits ().length == 0 ) {
175180 var maybeDefault = idMatchedDefault (inferenceEntityId , defaultConfigIds );
176- if (maybeDefault .isPresent ()) {
181+ if (DefaultElserFeatureFlag . isEnabled () && maybeDefault .isPresent ()) {
177182 getDefaultConfig (true , maybeDefault .get (), listener );
178183 } else {
179184 delegate .onFailure (inferenceNotFoundException (inferenceEntityId ));
@@ -209,8 +214,12 @@ private ResourceNotFoundException inferenceNotFoundException(String inferenceEnt
209214 public void getModelsByTaskType (TaskType taskType , ActionListener <List <UnparsedModel >> listener ) {
210215 ActionListener <SearchResponse > searchListener = listener .delegateFailureAndWrap ((delegate , searchResponse ) -> {
211216 var modelConfigs = parseHitsAsModels (searchResponse .getHits ()).stream ().map (ModelRegistry ::unparsedModelFromMap ).toList ();
212- var defaultConfigsForTaskType = taskTypeMatchedDefaults (taskType , defaultConfigIds );
213- addAllDefaultConfigsIfMissing (true , modelConfigs , defaultConfigsForTaskType , delegate );
217+ if (DefaultElserFeatureFlag .isEnabled ()) {
218+ var defaultConfigsForTaskType = taskTypeMatchedDefaults (taskType , defaultConfigIds );
219+ addAllDefaultConfigsIfMissing (true , modelConfigs , defaultConfigsForTaskType , delegate );
220+ } else {
221+ delegate .onResponse (modelConfigs );
222+ }
214223 });
215224
216225 QueryBuilder queryBuilder = QueryBuilders .constantScoreQuery (QueryBuilders .termsQuery (TASK_TYPE_FIELD , taskType .toString ()));
@@ -240,7 +249,11 @@ public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedM
240249 public void getAllModels (boolean persistDefaultEndpoints , ActionListener <List <UnparsedModel >> listener ) {
241250 ActionListener <SearchResponse > searchListener = listener .delegateFailureAndWrap ((delegate , searchResponse ) -> {
242251 var foundConfigs = parseHitsAsModels (searchResponse .getHits ()).stream ().map (ModelRegistry ::unparsedModelFromMap ).toList ();
243- addAllDefaultConfigsIfMissing (persistDefaultEndpoints , foundConfigs , defaultConfigIds , delegate );
252+ if (DefaultElserFeatureFlag .isEnabled ()) {
253+ addAllDefaultConfigsIfMissing (persistDefaultEndpoints , foundConfigs , defaultConfigIds , delegate );
254+ } else {
255+ delegate .onResponse (foundConfigs );
256+ }
244257 });
245258
246259 // In theory the index should only contain model config documents
@@ -264,26 +277,32 @@ private void addAllDefaultConfigsIfMissing(
264277 List <InferenceService .DefaultConfigId > matchedDefaults ,
265278 ActionListener <List <UnparsedModel >> listener
266279 ) {
267- var foundIds = foundConfigs .stream ().map (UnparsedModel ::inferenceEntityId ).collect (Collectors .toSet ());
268- var missing = matchedDefaults .stream ().filter (d -> foundIds .contains (d .inferenceId ()) == false ).toList ();
280+ if (DefaultElserFeatureFlag .isEnabled ()) {
269281
270- if (missing .isEmpty ()) {
271- listener .onResponse (foundConfigs );
272- } else {
273- var groupedListener = new GroupedActionListener <UnparsedModel >(
274- missing .size (),
275- listener .delegateFailure ((delegate , listOfModels ) -> {
276- var allConfigs = new ArrayList <UnparsedModel >();
277- allConfigs .addAll (foundConfigs );
278- allConfigs .addAll (listOfModels );
279- allConfigs .sort (Comparator .comparing (UnparsedModel ::inferenceEntityId ));
280- delegate .onResponse (allConfigs );
281- })
282- );
282+ var foundIds = foundConfigs .stream ().map (UnparsedModel ::inferenceEntityId ).collect (Collectors .toSet ());
283+ var missing = matchedDefaults .stream ().filter (d -> foundIds .contains (d .inferenceId ()) == false ).toList ();
284+
285+ if (missing .isEmpty ()) {
286+ listener .onResponse (foundConfigs );
287+ } else {
288+ var groupedListener = new GroupedActionListener <UnparsedModel >(
289+ missing .size (),
290+ listener .delegateFailure ((delegate , listOfModels ) -> {
291+ var allConfigs = new ArrayList <UnparsedModel >();
292+ allConfigs .addAll (foundConfigs );
293+ allConfigs .addAll (listOfModels );
294+ allConfigs .sort (Comparator .comparing (UnparsedModel ::inferenceEntityId ));
295+ delegate .onResponse (allConfigs );
296+ })
297+ );
283298
284- for (var required : missing ) {
285- getDefaultConfig (persistDefaultEndpoints , required , groupedListener );
299+ for (var required : missing ) {
300+ getDefaultConfig (persistDefaultEndpoints , required , groupedListener );
301+ }
286302 }
303+ } else {
304+ logger .error ("Attempted to add default configs with the feature flag disabled" );
305+ assert false ;
287306 }
288307 }
289308
@@ -292,40 +311,52 @@ private void getDefaultConfig(
292311 InferenceService .DefaultConfigId defaultConfig ,
293312 ActionListener <UnparsedModel > listener
294313 ) {
295- defaultConfig .service ().defaultConfigs (listener .delegateFailureAndWrap ((delegate , models ) -> {
296- boolean foundModel = false ;
297- for (var m : models ) {
298- if (m .getInferenceEntityId ().equals (defaultConfig .inferenceId ())) {
299- foundModel = true ;
300- if (persistDefaultEndpoints ) {
301- storeDefaultEndpoint (m , () -> listener .onResponse (modelToUnparsedModel (m )));
302- } else {
303- listener .onResponse (modelToUnparsedModel (m ));
314+ if (DefaultElserFeatureFlag .isEnabled ()) {
315+
316+ defaultConfig .service ().defaultConfigs (listener .delegateFailureAndWrap ((delegate , models ) -> {
317+ boolean foundModel = false ;
318+ for (var m : models ) {
319+ if (m .getInferenceEntityId ().equals (defaultConfig .inferenceId ())) {
320+ foundModel = true ;
321+ if (persistDefaultEndpoints ) {
322+ storeDefaultEndpoint (m , () -> listener .onResponse (modelToUnparsedModel (m )));
323+ } else {
324+ listener .onResponse (modelToUnparsedModel (m ));
325+ }
326+ break ;
304327 }
305- break ;
306328 }
307- }
308329
309- if (foundModel == false ) {
310- listener .onFailure (
311- new IllegalStateException ("Configuration not found for default inference id [" + defaultConfig .inferenceId () + "]" )
312- );
313- }
314- }));
330+ if (foundModel == false ) {
331+ listener .onFailure (
332+ new IllegalStateException ("Configuration not found for default inference id [" + defaultConfig .inferenceId () + "]" )
333+ );
334+ }
335+ }));
336+ } else {
337+ logger .error ("Attempted to get default configs with the feature flag disabled" );
338+ assert false ;
339+ }
315340 }
316341
317342 private void storeDefaultEndpoint (Model preconfigured , Runnable runAfter ) {
318- var responseListener = ActionListener .<Boolean >wrap (success -> {
319- logger .debug ("Added default inference endpoint [{}]" , preconfigured .getInferenceEntityId ());
320- }, exception -> {
321- if (exception instanceof ResourceAlreadyExistsException ) {
322- logger .debug ("Default inference id [{}] already exists" , preconfigured .getInferenceEntityId ());
323- } else {
324- logger .error ("Failed to store default inference id [" + preconfigured .getInferenceEntityId () + "]" , exception );
325- }
326- });
343+ if (DefaultElserFeatureFlag .isEnabled ()) {
344+
345+ var responseListener = ActionListener .<Boolean >wrap (success -> {
346+ logger .debug ("Added default inference endpoint [{}]" , preconfigured .getInferenceEntityId ());
347+ }, exception -> {
348+ if (exception instanceof ResourceAlreadyExistsException ) {
349+ logger .debug ("Default inference id [{}] already exists" , preconfigured .getInferenceEntityId ());
350+ } else {
351+ logger .error ("Failed to store default inference id [" + preconfigured .getInferenceEntityId () + "]" , exception );
352+ }
353+ });
327354
328- storeModel (preconfigured , ActionListener .runAfter (responseListener , runAfter ));
355+ storeModel (preconfigured , ActionListener .runAfter (responseListener , runAfter ));
356+ } else {
357+ logger .error ("Attempted to store default endpoint with the feature flag disabled" );
358+ assert false ;
359+ }
329360 }
330361
331362 private ArrayList <ModelConfigMap > parseHitsAsModels (SearchHits hits ) {
@@ -673,6 +704,7 @@ static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(
673704 TaskType taskType ,
674705 List <InferenceService .DefaultConfigId > defaultConfigIds
675706 ) {
707+ assert DefaultElserFeatureFlag .isEnabled ();
676708 return defaultConfigIds .stream ()
677709 .filter (defaultConfigId -> defaultConfigId .taskType ().equals (taskType ))
678710 .collect (Collectors .toList ());
0 commit comments