@@ -51,24 +51,6 @@ public void testAttachToDeployment() throws IOException {
         var results = infer(inferenceId, List.of("washing machine"));
         assertNotNull(results.get("sparse_embedding"));
 
-        var updatedNumAllocations = randomIntBetween(1, 10);
-        var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
-        assertThat(
-            updatedEndpointConfig.get("service_settings"),
-            is(
-                Map.of(
-                    "num_allocations",
-                    updatedNumAllocations,
-                    "num_threads",
-                    1,
-                    "model_id",
-                    "attach_to_deployment",
-                    "deployment_id",
-                    "existing_deployment"
-                )
-            )
-        );
-
         deleteModel(inferenceId);
         // assert deployment not stopped
         var stats = (List<Map<String, Object>>) getTrainedModelStats(modelId).get("trained_model_stats");
@@ -128,24 +110,6 @@ public void testAttachWithModelId() throws IOException {
         var results = infer(inferenceId, List.of("washing machine"));
         assertNotNull(results.get("sparse_embedding"));
 
-        var updatedNumAllocations = randomIntBetween(1, 10);
-        var updatedEndpointConfig = updateEndpoint(inferenceId, updatedEndpointConfig(updatedNumAllocations), TaskType.SPARSE_EMBEDDING);
-        assertThat(
-            updatedEndpointConfig.get("service_settings"),
-            is(
-                Map.of(
-                    "num_allocations",
-                    updatedNumAllocations,
-                    "num_threads",
-                    1,
-                    "model_id",
-                    "attach_with_model_id",
-                    "deployment_id",
-                    "existing_deployment_with_model_id"
-                )
-            )
-        );
-
         forceStopMlNodeDeployment(deploymentId);
     }
 
@@ -180,6 +144,30 @@ public void testDeploymentDoesNotExist() {
         assertThat(e.getMessage(), containsString("Cannot find deployment [missing_deployment]"));
     }
 
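+    // An inference endpoint ID that clashes with an existing trained model ID should be rejected.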
+    public void testCreateInferenceUsingSameDeploymentId() throws IOException {
+        var modelId = "conflicting_ids";
+        var deploymentId = modelId;
+        var inferenceId = modelId;
+
+        CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
+        var response = startMlNodeDeploymemnt(modelId, deploymentId);
+        assertStatusOkOrCreated(response);
+
+        var responseException = assertThrows(
+            ResponseException.class,
+            () -> putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING)
+        );
+        assertThat(
+            responseException.getMessage(),
+            containsString(
+                "Inference endpoint IDs must be unique. "
+                    + "Requested inference endpoint ID [conflicting_ids] matches existing trained model ID(s) but must not."
+            )
+        );
+
+        forceStopMlNodeDeployment(deploymentId);
+    }
+
     public void testNumAllocationsIsUpdated() throws IOException {
         var modelId = "update_num_allocations";
         var deploymentId = modelId;
@@ -208,7 +196,16 @@ public void testNumAllocationsIsUpdated() throws IOException {
             )
         )
 
-        assertStatusOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2));
+        var responseException = assertThrows(ResponseException.class, () -> updateInference(inferenceId, TaskType.SPARSE_EMBEDDING, 2));
+        assertThat(
+            responseException.getMessage(),
+            containsString(
+                "Cannot update inference endpoint [test_num_allocations_updated] using model deployment [update_num_allocations]. "
+                    + "The model deployment must be updated through the trained models API."
+            )
+        );
+
+        updateMlNodeDeploymemnt(deploymentId, 2);
 
         var updatedServiceSettings = getModel(inferenceId).get("service_settings");
         assertThat(
@@ -227,6 +224,92 @@ public void testNumAllocationsIsUpdated() throws IOException {
                 )
             )
         );
+
+        forceStopMlNodeDeployment(deploymentId);
+    }
+
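+    // When the endpoint itself created the deployment, allocations are updated via the inference API;
+    // the trained models deployment _update API rejects the change.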
+    public void testUpdateWhenInferenceEndpointCreatesDeployment() throws IOException {
+        var modelId = "update_num_allocations_from_created_endpoint";
+        var inferenceId = "test_created_endpoint_from_model";
+        var deploymentId = inferenceId;
+
+        CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
+
+        var putModel = putModel(inferenceId, Strings.format("""
+            {
+              "service": "elasticsearch",
+              "service_settings": {
+                "num_allocations": %s,
+                "num_threads": %s,
+                "model_id": "%s"
+              }
+            }
+            """, 1, 1, modelId), TaskType.SPARSE_EMBEDDING);
+        var serviceSettings = putModel.get("service_settings");
+        assertThat(putModel.toString(), serviceSettings, is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", modelId)));
+
+        updateInference(inferenceId, TaskType.SPARSE_EMBEDDING, 2);
+
+        var responseException = assertThrows(ResponseException.class, () -> updateMlNodeDeploymemnt(deploymentId, 2));
+        assertThat(
+            responseException.getMessage(),
+            containsString(
+                "Cannot update deployment [test_created_endpoint_from_model] as it was created by inference endpoint "
+                    + "[test_created_endpoint_from_model]. This model deployment must be updated through the inference API."
+            )
+        );
+
+        var updatedServiceSettings = getModel(inferenceId).get("service_settings");
+        assertThat(
+            updatedServiceSettings.toString(),
+            updatedServiceSettings,
+            is(Map.of("num_allocations", 2, "num_threads", 1, "model_id", modelId))
+        );
+
+        forceStopMlNodeDeployment(deploymentId);
+    }
+
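+    // A deployment created by one inference endpoint cannot be updated through a second endpoint that merely attaches to it.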
+    public void testCannotUpdateAnotherInferenceEndpointsCreatedDeployment() throws IOException {
+        var modelId = "model_deployment_for_endpoint";
+        var inferenceId = "first_endpoint_for_model_deployment";
+        var deploymentId = inferenceId;
+
+        CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
+
+        putModel(inferenceId, Strings.format("""
+            {
+              "service": "elasticsearch",
+              "service_settings": {
+                "num_allocations": %s,
+                "num_threads": %s,
+                "model_id": "%s"
+              }
+            }
+            """, 1, 1, modelId), TaskType.SPARSE_EMBEDDING);
+
+        var secondInferenceId = "second_endpoint_for_model_deployment";
+        var putModel = putModel(secondInferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
+        var serviceSettings = putModel.get("service_settings");
+        assertThat(
+            putModel.toString(),
+            serviceSettings,
+            is(Map.of("num_allocations", 1, "num_threads", 1, "model_id", modelId, "deployment_id", deploymentId))
+        );
+
+        var responseException = assertThrows(
+            ResponseException.class,
+            () -> updateInference(secondInferenceId, TaskType.SPARSE_EMBEDDING, 2)
+        );
+        assertThat(
+            responseException.getMessage(),
+            containsString(
+                "Cannot update inference endpoint [second_endpoint_for_model_deployment] for model deployment "
+                    + "[first_endpoint_for_model_deployment] as it was created by another inference endpoint. "
+                    + "The model can only be updated using inference endpoint id [first_endpoint_for_model_deployment]."
+            )
+        );
+
+        forceStopMlNodeDeployment(deploymentId);
     }
 
     public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOException {
@@ -300,6 +383,22 @@ private Response startMlNodeDeploymemnt(String modelId, String deploymentId) thr
         return client().performRequest(request);
     }
 
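+    // Updates num_allocations on an inference endpoint via PUT /_inference/<task_type>/<id>/_update.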
+    private Response updateInference(String deploymentId, TaskType taskType, int numAllocations) throws IOException {
+        String endPoint = Strings.format("/_inference/%s/%s/_update", taskType, deploymentId);
+
+        var body = Strings.format("""
+            {
+              "service_settings": {
+                "num_allocations": %d
+              }
+            }
+            """, numAllocations);
+
+        Request request = new Request("PUT", endPoint);
+        request.setJsonEntity(body);
+        return client().performRequest(request);
+    }
+
     private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations) throws IOException {
         String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update";
 
@@ -314,6 +413,16 @@ private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations
         return client().performRequest(request);
     }
 
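+    // POSTs a raw JSON body to the trained models deployment _update API, asserts an OK/Created response, and returns it as a map.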
+    private Map<String, Object> updateMlNodeDeploymemnt(String deploymentId, String body) throws IOException {
+        String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update";
+
+        Request request = new Request("POST", endPoint);
+        request.setJsonEntity(body);
+        var response = client().performRequest(request);
+        assertStatusOkOrCreated(response);
+        return entityAsMap(response);
+    }
+
     protected void stopMlNodeDeployment(String deploymentId) throws IOException {
         String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
         Request request = new Request("POST", endpoint);