Skip to content

Commit f888c27

Browse files
committed
fix uneeded call to get model_id for task api within RestMLGuardrailsIT
Following opensearch-project#3244 this IT called the task api to check the model id again however this is redundant. Instead one can directly pull the model_id upon creating the model group. Manual testing was done to see that the behavior is intact, this should help reduce the calls within a IT to make it less flaky Signed-off-by: Brian Flores <[email protected]>
1 parent 1d30671 commit f888c27

File tree

2 files changed

+31
-34
lines changed

2 files changed

+31
-34
lines changed

plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,7 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt
977977
}
978978
return taskDone.get();
979979
}, CUSTOM_MODEL_TIMEOUT, TimeUnit.SECONDS);
980-
assertTrue(taskDone.get());
980+
assertTrue(String.format(Locale.ROOT, "Task Id %s could not get to %s state", taskId, targetState.name()), taskDone.get());
981981
}
982982

983983
public String registerConnector(String createConnectorInput) throws IOException, InterruptedException {

plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -124,17 +124,16 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep
124124
Response response = createConnector(completionModelConnectorEntity);
125125
Map responseMap = parseResponseToMap(response);
126126
String connectorId = (String) responseMap.get("connector_id");
127+
127128
response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId);
128129
responseMap = parseResponseToMap(response);
129-
String taskId = (String) responseMap.get("task_id");
130-
waitForTask(taskId, MLTaskState.COMPLETED);
131-
response = getTask(taskId);
132-
responseMap = parseResponseToMap(response);
133130
String modelId = (String) responseMap.get("model_id");
131+
134132
response = deployRemoteModel(modelId);
135133
responseMap = parseResponseToMap(response);
136-
taskId = (String) responseMap.get("task_id");
134+
String taskId = (String) responseMap.get("task_id");
137135
waitForTask(taskId, MLTaskState.COMPLETED);
136+
138137
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
139138
response = predictRemoteModel(modelId, predictInput);
140139
responseMap = parseResponseToMap(response);
@@ -144,6 +143,7 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep
144143
responseMap = (Map) responseList.get(0);
145144
responseMap = (Map) responseMap.get("dataAsMap");
146145
responseList = (List) responseMap.get("choices");
146+
147147
if (responseList == null) {
148148
assertTrue(checkThrottlingOpenAI(responseMap));
149149
return;
@@ -160,18 +160,18 @@ public void testPredictRemoteModelFailed() throws IOException, InterruptedExcept
160160
exceptionRule.expect(ResponseException.class);
161161
exceptionRule.expectMessage("guardrails triggered for user input");
162162
Response response = createConnector(completionModelConnectorEntity);
163+
163164
Map responseMap = parseResponseToMap(response);
164165
String connectorId = (String) responseMap.get("connector_id");
166+
165167
response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId);
166168
responseMap = parseResponseToMap(response);
167-
String taskId = (String) responseMap.get("task_id");
168-
waitForTask(taskId, MLTaskState.COMPLETED);
169-
response = getTask(taskId);
170-
responseMap = parseResponseToMap(response);
171169
String modelId = (String) responseMap.get("model_id");
170+
172171
response = deployRemoteModel(modelId);
173172
responseMap = parseResponseToMap(response);
174-
taskId = (String) responseMap.get("task_id");
173+
String taskId = (String) responseMap.get("task_id");
174+
175175
waitForTask(taskId, MLTaskState.COMPLETED);
176176
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}";
177177
predictRemoteModel(modelId, predictInput);
@@ -187,17 +187,16 @@ public void testPredictRemoteModelFailedNonType() throws IOException, Interrupte
187187
Response response = createConnector(completionModelConnectorEntity);
188188
Map responseMap = parseResponseToMap(response);
189189
String connectorId = (String) responseMap.get("connector_id");
190+
190191
response = registerRemoteModelNonTypeGuardrails("openAI-GPT-3.5 completions", connectorId);
191192
responseMap = parseResponseToMap(response);
192-
String taskId = (String) responseMap.get("task_id");
193-
waitForTask(taskId, MLTaskState.COMPLETED);
194-
response = getTask(taskId);
195-
responseMap = parseResponseToMap(response);
196193
String modelId = (String) responseMap.get("model_id");
194+
197195
response = deployRemoteModel(modelId);
198196
responseMap = parseResponseToMap(response);
199-
taskId = (String) responseMap.get("task_id");
197+
String taskId = (String) responseMap.get("task_id");
200198
waitForTask(taskId, MLTaskState.COMPLETED);
199+
201200
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}";
202201
predictRemoteModel(modelId, predictInput);
203202
}
@@ -211,17 +210,16 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException
211210
Response response = createConnector(completionModelConnectorEntityWithGuardrail);
212211
Map responseMap = parseResponseToMap(response);
213212
String guardrailConnectorId = (String) responseMap.get("connector_id");
213+
214214
response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId);
215215
responseMap = parseResponseToMap(response);
216-
String taskId = (String) responseMap.get("task_id");
217-
waitForTask(taskId, MLTaskState.COMPLETED);
218-
response = getTask(taskId);
219-
responseMap = parseResponseToMap(response);
220216
String guardrailModelId = (String) responseMap.get("model_id");
217+
221218
response = deployRemoteModel(guardrailModelId);
222219
responseMap = parseResponseToMap(response);
223-
taskId = (String) responseMap.get("task_id");
220+
String taskId = (String) responseMap.get("task_id");
224221
waitForTask(taskId, MLTaskState.COMPLETED);
222+
225223
// Check the response from guardrails model that should be "accept".
226224
String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}";
227225
response = predictRemoteModel(guardrailModelId, predictInput);
@@ -233,21 +231,21 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException
233231
responseMap = (Map) responseMap.get("dataAsMap");
234232
String validationResult = (String) responseMap.get("response");
235233
Assert.assertTrue(validateRegex(validationResult, acceptRegex));
234+
236235
// Create predict model.
237236
response = createConnector(completionModelConnectorEntity);
238237
responseMap = parseResponseToMap(response);
239238
String connectorId = (String) responseMap.get("connector_id");
239+
240240
response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId);
241241
responseMap = parseResponseToMap(response);
242-
taskId = (String) responseMap.get("task_id");
243-
waitForTask(taskId, MLTaskState.COMPLETED);
244-
response = getTask(taskId);
245-
responseMap = parseResponseToMap(response);
246242
String modelId = (String) responseMap.get("model_id");
243+
247244
response = deployRemoteModel(modelId);
248245
responseMap = parseResponseToMap(response);
249246
taskId = (String) responseMap.get("task_id");
250247
waitForTask(taskId, MLTaskState.COMPLETED);
248+
251249
// Predict.
252250
predictInput = "{\n"
253251
+ " \"parameters\": {\n"
@@ -282,17 +280,17 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException,
282280
Response response = createConnector(completionModelConnectorEntityWithGuardrail);
283281
Map responseMap = parseResponseToMap(response);
284282
String guardrailConnectorId = (String) responseMap.get("connector_id");
283+
284+
//Create the model ID
285285
response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId);
286286
responseMap = parseResponseToMap(response);
287-
String taskId = (String) responseMap.get("task_id");
288-
waitForTask(taskId, MLTaskState.COMPLETED);
289-
response = getTask(taskId);
290-
responseMap = parseResponseToMap(response);
291287
String guardrailModelId = (String) responseMap.get("model_id");
288+
292289
response = deployRemoteModel(guardrailModelId);
293290
responseMap = parseResponseToMap(response);
294-
taskId = (String) responseMap.get("task_id");
291+
String taskId = (String) responseMap.get("task_id");
295292
waitForTask(taskId, MLTaskState.COMPLETED);
293+
296294
// Check the response from guardrails model that should be "reject".
297295
String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"I will be executed or tortured.\"\n" + " }\n" + "}";
298296
response = predictRemoteModel(guardrailModelId, predictInput);
@@ -304,17 +302,16 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException,
304302
responseMap = (Map) responseMap.get("dataAsMap");
305303
String validationResult = (String) responseMap.get("response");
306304
Assert.assertTrue(validateRegex(validationResult, rejectRegex));
305+
307306
// Create predict model.
308307
response = createConnector(completionModelConnectorEntity);
309308
responseMap = parseResponseToMap(response);
310309
String connectorId = (String) responseMap.get("connector_id");
310+
311311
response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId);
312312
responseMap = parseResponseToMap(response);
313-
taskId = (String) responseMap.get("task_id");
314-
waitForTask(taskId, MLTaskState.COMPLETED);
315-
response = getTask(taskId);
316-
responseMap = parseResponseToMap(response);
317313
String modelId = (String) responseMap.get("model_id");
314+
318315
response = deployRemoteModel(modelId);
319316
responseMap = parseResponseToMap(response);
320317
taskId = (String) responseMap.get("task_id");

0 commit comments

Comments
 (0)