@@ -124,17 +124,16 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedException {
         Response response = createConnector(completionModelConnectorEntity);
         Map responseMap = parseResponseToMap(response);
         String connectorId = (String) responseMap.get("connector_id");
+
         response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId);
         responseMap = parseResponseToMap(response);
-        String taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        response = getTask(taskId);
-        responseMap = parseResponseToMap(response);
         String modelId = (String) responseMap.get("model_id");
+
         response = deployRemoteModel(modelId);
         responseMap = parseResponseToMap(response);
-        taskId = (String) responseMap.get("task_id");
+        String taskId = (String) responseMap.get("task_id");
         waitForTask(taskId, MLTaskState.COMPLETED);
+
         String predictInput = "{\n" + "  \"parameters\": {\n" + "    \"prompt\": \"Say this is a test\"\n" + "  }\n" + "}";
         response = predictRemoteModel(modelId, predictInput);
         responseMap = parseResponseToMap(response);
@@ -144,6 +143,7 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedException {
         responseMap = (Map) responseList.get(0);
         responseMap = (Map) responseMap.get("dataAsMap");
         responseList = (List) responseMap.get("choices");
+
         if (responseList == null) {
             assertTrue(checkThrottlingOpenAI(responseMap));
             return;
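
For reference, a minimal sketch of the happy-path flow these hunks converge on, written as a method-body snippet using the helper names already present in this test class (registerRemoteModelWithLocalRegexGuardrails, deployRemoteModel, parseResponseToMap, waitForTask, predictRemoteModel). The key assumption the diff encodes is that the register response now carries "model_id" directly, so only the deploy call still returns a "task_id" to poll:

    // Register the remote model: the response is assumed to contain "model_id" directly,
    // so the old waitForTask(...) + getTask(...) round trip is no longer needed here.
    Response register = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId);
    String modelId = (String) parseResponseToMap(register).get("model_id");

    // Deploy is still asynchronous: grab the returned task id and wait for it to complete.
    Response deploy = deployRemoteModel(modelId);
    String deployTaskId = (String) parseResponseToMap(deploy).get("task_id");
    waitForTask(deployTaskId, MLTaskState.COMPLETED);

    // Predict against the deployed model; the local regex guardrail is applied to the prompt.
    Response prediction = predictRemoteModel(modelId, "{\"parameters\": {\"prompt\": \"Say this is a test\"}}");
    Map predictionMap = parseResponseToMap(prediction);
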
@@ -160,18 +160,18 @@ public void testPredictRemoteModelFailed() throws IOException, InterruptedException {
         exceptionRule.expect(ResponseException.class);
         exceptionRule.expectMessage("guardrails triggered for user input");
         Response response = createConnector(completionModelConnectorEntity);
+
         Map responseMap = parseResponseToMap(response);
         String connectorId = (String) responseMap.get("connector_id");
+
         response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId);
         responseMap = parseResponseToMap(response);
-        String taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        response = getTask(taskId);
-        responseMap = parseResponseToMap(response);
         String modelId = (String) responseMap.get("model_id");
+
         response = deployRemoteModel(modelId);
         responseMap = parseResponseToMap(response);
-        taskId = (String) responseMap.get("task_id");
+        String taskId = (String) responseMap.get("task_id");
+
         waitForTask(taskId, MLTaskState.COMPLETED);
         String predictInput = "{\n" + "  \"parameters\": {\n" + "    \"prompt\": \"Say this is a test of stop word.\"\n" + "  }\n" + "}";
         predictRemoteModel(modelId, predictInput);
@@ -187,17 +187,16 @@ public void testPredictRemoteModelFailedNonType() throws IOException, InterruptedException {
         Response response = createConnector(completionModelConnectorEntity);
         Map responseMap = parseResponseToMap(response);
         String connectorId = (String) responseMap.get("connector_id");
+
         response = registerRemoteModelNonTypeGuardrails("openAI-GPT-3.5 completions", connectorId);
         responseMap = parseResponseToMap(response);
-        String taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        response = getTask(taskId);
-        responseMap = parseResponseToMap(response);
         String modelId = (String) responseMap.get("model_id");
+
         response = deployRemoteModel(modelId);
         responseMap = parseResponseToMap(response);
-        taskId = (String) responseMap.get("task_id");
+        String taskId = (String) responseMap.get("task_id");
         waitForTask(taskId, MLTaskState.COMPLETED);
+
         String predictInput = "{\n" + "  \"parameters\": {\n" + "    \"prompt\": \"Say this is a test of stop word.\"\n" + "  }\n" + "}";
         predictRemoteModel(modelId, predictInput);
     }
@@ -211,17 +210,16 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException, InterruptedException {
         Response response = createConnector(completionModelConnectorEntityWithGuardrail);
         Map responseMap = parseResponseToMap(response);
         String guardrailConnectorId = (String) responseMap.get("connector_id");
+
         response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId);
         responseMap = parseResponseToMap(response);
-        String taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        response = getTask(taskId);
-        responseMap = parseResponseToMap(response);
         String guardrailModelId = (String) responseMap.get("model_id");
+
         response = deployRemoteModel(guardrailModelId);
         responseMap = parseResponseToMap(response);
-        taskId = (String) responseMap.get("task_id");
+        String taskId = (String) responseMap.get("task_id");
         waitForTask(taskId, MLTaskState.COMPLETED);
+
         // Check the response from guardrails model that should be "accept".
         String predictInput = "{\n" + "  \"parameters\": {\n" + "    \"question\": \"hello\"\n" + "  }\n" + "}";
         response = predictRemoteModel(guardrailModelId, predictInput);
@@ -233,21 +231,21 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException, InterruptedException {
         responseMap = (Map) responseMap.get("dataAsMap");
         String validationResult = (String) responseMap.get("response");
         Assert.assertTrue(validateRegex(validationResult, acceptRegex));
+
         // Create predict model.
         response = createConnector(completionModelConnectorEntity);
         responseMap = parseResponseToMap(response);
         String connectorId = (String) responseMap.get("connector_id");
+
         response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId);
         responseMap = parseResponseToMap(response);
-        taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        response = getTask(taskId);
-        responseMap = parseResponseToMap(response);
         String modelId = (String) responseMap.get("model_id");
+
         response = deployRemoteModel(modelId);
         responseMap = parseResponseToMap(response);
         taskId = (String) responseMap.get("task_id");
         waitForTask(taskId, MLTaskState.COMPLETED);
+
         // Predict.
         predictInput = "{\n"
             + "  \"parameters\": {\n"
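
The model-guardrail variant above sets up two models in sequence. A condensed sketch of that ordering, again assuming the helper names from this test class and the same register-returns-model_id behavior for both registrations:

    // 1. Register and deploy the guardrail model itself.
    String guardrailModelId = (String) parseResponseToMap(
        registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId)
    ).get("model_id");
    String deployTaskId = (String) parseResponseToMap(deployRemoteModel(guardrailModelId)).get("task_id");
    waitForTask(deployTaskId, MLTaskState.COMPLETED);

    // 2. Register the completion model with that guardrail model attached,
    //    then deploy it the same way before predicting.
    String modelId = (String) parseResponseToMap(
        registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId)
    ).get("model_id");
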
@@ -282,17 +280,17 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException, InterruptedException {
         Response response = createConnector(completionModelConnectorEntityWithGuardrail);
         Map responseMap = parseResponseToMap(response);
         String guardrailConnectorId = (String) responseMap.get("connector_id");
+
+        // Create the model ID
         response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId);
         responseMap = parseResponseToMap(response);
-        String taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        response = getTask(taskId);
-        responseMap = parseResponseToMap(response);
         String guardrailModelId = (String) responseMap.get("model_id");
+
         response = deployRemoteModel(guardrailModelId);
         responseMap = parseResponseToMap(response);
-        taskId = (String) responseMap.get("task_id");
+        String taskId = (String) responseMap.get("task_id");
         waitForTask(taskId, MLTaskState.COMPLETED);
+
         // Check the response from guardrails model that should be "reject".
         String predictInput = "{\n" + "  \"parameters\": {\n" + "    \"question\": \"I will be executed or tortured.\"\n" + "  }\n" + "}";
         response = predictRemoteModel(guardrailModelId, predictInput);
@@ -304,17 +302,16 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException, InterruptedException {
         responseMap = (Map) responseMap.get("dataAsMap");
         String validationResult = (String) responseMap.get("response");
         Assert.assertTrue(validateRegex(validationResult, rejectRegex));
+
         // Create predict model.
         response = createConnector(completionModelConnectorEntity);
         responseMap = parseResponseToMap(response);
         String connectorId = (String) responseMap.get("connector_id");
+
         response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId);
         responseMap = parseResponseToMap(response);
-        taskId = (String) responseMap.get("task_id");
-        waitForTask(taskId, MLTaskState.COMPLETED);
-        response = getTask(taskId);
-        responseMap = parseResponseToMap(response);
         String modelId = (String) responseMap.get("model_id");
+
         response = deployRemoteModel(modelId);
         responseMap = parseResponseToMap(response);
         taskId = (String) responseMap.get("task_id");
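
The validateRegex(validationResult, acceptRegex) and validateRegex(validationResult, rejectRegex) assertions above check the guardrail model's textual verdict, but the helper's implementation is not part of this diff. The following is only a hypothetical sketch of such a check using java.util.regex:

    import java.util.regex.Pattern;

    // Hypothetical stand-in for this test class's validateRegex helper:
    // true when the guardrail model's response text matches the given pattern.
    final class RegexCheck {
        static boolean validateRegex(String response, String regex) {
            // find() rather than matches(), so an accept/reject pattern may match
            // a substring of a longer model response.
            return Pattern.compile(regex).matcher(response).find();
        }
    }
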