Skip to content

Commit 073b372

Browse files
authored
support step size for embedding model which outputs less embeddings (#1586)
* support step size for embedding model which outputs less embeddings Signed-off-by: Yaliang Wu <[email protected]> * tune parameter name Signed-off-by: Yaliang Wu <[email protected]> * fine tune processed doc to always respect step size Signed-off-by: Yaliang Wu <[email protected]> --------- Signed-off-by: Yaliang Wu <[email protected]>
1 parent c6bef30 commit 073b372

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.HashMap;
2323
import java.util.List;
2424
import java.util.Map;
25+
import java.util.Optional;
2526

2627
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
2728

@@ -41,7 +42,14 @@ default ModelTensorOutput executePredict(MLInput mlInput) {
4142
if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) {
4243
tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size();
4344
}
44-
processedDocs += Math.max(tensorCount, 1);
45+
// This is to support some model which takes N text docs and embedding size is less than N-1.
46+
// We need to tell executor what's the step size for each model run.
47+
Map<String, String> parameters = getConnector().getParameters();
48+
if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) {
49+
processedDocs += Integer.parseInt(parameters.get("input_docs_processed_step_size"));
50+
} else {
51+
processedDocs += Math.max(tensorCount, 1);
52+
}
4553
tensorOutputs.addAll(tempTensorOutputs);
4654
}
4755

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
import java.io.IOException;
4141
import java.util.Arrays;
42+
import java.util.Map;
4243

4344
import static org.mockito.ArgumentMatchers.any;
4445
import static org.mockito.Mockito.mock;
@@ -165,7 +166,7 @@ public void executePredict_TextDocsInput() throws IOException {
165166
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
166167
.requestBody("{\"input\": ${parameters.input}}")
167168
.build();
168-
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
169+
HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
169170
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
170171
executor.setScriptService(scriptService);
171172
when(httpClient.execute(any())).thenReturn(response);
@@ -182,6 +183,7 @@ public void executePredict_TextDocsInput() throws IOException {
182183
HttpEntity entity = new StringEntity(modelResponse);
183184
when(response.getEntity()).thenReturn(entity);
184185
when(executor.getHttpClient()).thenReturn(httpClient);
186+
when(executor.getConnector()).thenReturn(connector);
185187
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
186188
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
187189
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
@@ -190,4 +192,46 @@ public void executePredict_TextDocsInput() throws IOException {
190192
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
191193
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData());
192194
}
195+
196+
@Test
197+
public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOException {
198+
String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
199+
String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }";
200+
when(scriptService.compile(any(), any()))
201+
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
202+
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));
203+
204+
ConnectorAction predictAction = ConnectorAction.builder()
205+
.actionType(ConnectorAction.ActionType.PREDICT)
206+
.method("POST")
207+
.url("http://test.com/mock")
208+
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
209+
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
210+
.requestBody("{\"input\": ${parameters.input}}")
211+
.build();
212+
Map<String, String> parameters = ImmutableMap.of("input_docs_processed_step_size", "2");
213+
HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
214+
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
215+
executor.setScriptService(scriptService);
216+
when(httpClient.execute(any())).thenReturn(response);
217+
// model takes 2 input docs, but only output 1 embedding
218+
String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n"
219+
+ " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n"
220+
+ " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n"
221+
+ " } ],\n"
222+
+ " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
223+
+ " \"total_tokens\": 5\n" + " }\n" + "}";
224+
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
225+
when(response.getStatusLine()).thenReturn(statusLine);
226+
HttpEntity entity = new StringEntity(modelResponse);
227+
when(response.getEntity()).thenReturn(entity);
228+
when(executor.getHttpClient()).thenReturn(httpClient);
229+
when(executor.getConnector()).thenReturn(connector);
230+
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
231+
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
232+
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
233+
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
234+
Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
235+
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
236+
}
193237
}

0 commit comments

Comments
 (0)