Skip to content

Commit 5da2d91

Browse files
authored
validate step size (#1587)
Signed-off-by: Yaliang Wu <[email protected]>
1 parent 073b372 commit 5da2d91

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,16 @@ default ModelTensorOutput executePredict(MLInput mlInput) {
4242
if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) {
4343
tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size();
4444
}
45-
// This is to support some model which takes N text docs and embedding size is less than N-1.
45+
// This supports models that take N text docs but return fewer than N embeddings.
4646
// We need to tell executor what's the step size for each model run.
4747
Map<String, String> parameters = getConnector().getParameters();
4848
if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) {
49-
processedDocs += Integer.parseInt(parameters.get("input_docs_processed_step_size"));
49+
int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size"));
50+
// Validate the parameter at runtime, because it can also be passed in the predict request.
51+
if (stepSize <= 0) {
52+
throw new IllegalArgumentException("Invalid parameter: input_docs_processed_step_size. It must be positive integer.");
53+
}
54+
processedDocs += stepSize;
5055
} else {
5156
processedDocs += Math.max(tensorCount, 1);
5257
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,4 +234,45 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOE
234234
Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
235235
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
236236
}
237+
238+
@Test
239+
public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs_InvalidStepSize() throws IOException {
240+
exceptionRule.expect(IllegalArgumentException.class);
241+
exceptionRule.expectMessage("Invalid parameter: input_docs_processed_step_size. It must be positive integer.");
242+
String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
243+
String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }";
244+
when(scriptService.compile(any(), any()))
245+
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
246+
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));
247+
248+
ConnectorAction predictAction = ConnectorAction.builder()
249+
.actionType(ConnectorAction.ActionType.PREDICT)
250+
.method("POST")
251+
.url("http://test.com/mock")
252+
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
253+
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
254+
.requestBody("{\"input\": ${parameters.input}}")
255+
.build();
256+
// Step size must be a positive integer; here it is set to -1, which should trigger an IllegalArgumentException.
257+
Map<String, String> parameters = ImmutableMap.of("input_docs_processed_step_size", "-1");
258+
HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
259+
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
260+
executor.setScriptService(scriptService);
261+
when(httpClient.execute(any())).thenReturn(response);
262+
// The model takes 2 input docs but outputs only 1 embedding.
263+
String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n"
264+
+ " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n"
265+
+ " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n"
266+
+ " } ],\n"
267+
+ " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
268+
+ " \"total_tokens\": 5\n" + " }\n" + "}";
269+
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
270+
when(response.getStatusLine()).thenReturn(statusLine);
271+
HttpEntity entity = new StringEntity(modelResponse);
272+
when(response.getEntity()).thenReturn(entity);
273+
when(executor.getHttpClient()).thenReturn(httpClient);
274+
when(executor.getConnector()).thenReturn(connector);
275+
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
276+
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
277+
}
237278
}

0 commit comments

Comments
 (0)