3939
4040import java .io .IOException ;
4141import java .util .Arrays ;
42+ import java .util .Map ;
4243
4344import static org .mockito .ArgumentMatchers .any ;
4445import static org .mockito .Mockito .mock ;
@@ -165,7 +166,7 @@ public void executePredict_TextDocsInput() throws IOException {
165166 .postProcessFunction (MLPostProcessFunction .OPENAI_EMBEDDING )
166167 .requestBody ("{\" input\" : ${parameters.input}}" )
167168 .build ();
168- Connector connector = HttpConnector .builder ().name ("test connector" ).version ("1" ).protocol ("http" ).actions (Arrays .asList (predictAction )).build ();
169+ HttpConnector connector = HttpConnector .builder ().name ("test connector" ).version ("1" ).protocol ("http" ).actions (Arrays .asList (predictAction )).build ();
169170 HttpJsonConnectorExecutor executor = spy (new HttpJsonConnectorExecutor (connector ));
170171 executor .setScriptService (scriptService );
171172 when (httpClient .execute (any ())).thenReturn (response );
@@ -182,6 +183,7 @@ public void executePredict_TextDocsInput() throws IOException {
182183 HttpEntity entity = new StringEntity (modelResponse );
183184 when (response .getEntity ()).thenReturn (entity );
184185 when (executor .getHttpClient ()).thenReturn (httpClient );
186+ when (executor .getConnector ()).thenReturn (connector );
185187 MLInputDataset inputDataSet = TextDocsInputDataSet .builder ().docs (Arrays .asList ("test doc1" , "test doc2" )).build ();
186188 ModelTensorOutput modelTensorOutput = executor .executePredict (MLInput .builder ().algorithm (FunctionName .REMOTE ).inputDataset (inputDataSet ).build ());
187189 Assert .assertEquals (1 , modelTensorOutput .getMlModelOutputs ().size ());
@@ -190,4 +192,46 @@ public void executePredict_TextDocsInput() throws IOException {
190192 Assert .assertArrayEquals (new Number [] {-0.014555434 , -0.002135904 , 0.0035105038 }, modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ().get (0 ).getData ());
191193 Assert .assertArrayEquals (new Number [] {-0.014555434 , -0.002135904 , 0.0035105038 }, modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ().get (1 ).getData ());
192194 }
195+
196+ @ Test
197+ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs () throws IOException {
198+ String preprocessResult1 = "{\" parameters\" : { \" input\" : \" test doc1\" } }" ;
199+ String preprocessResult2 = "{\" parameters\" : { \" input\" : \" test doc2\" } }" ;
200+ when (scriptService .compile (any (), any ()))
201+ .then (invocation -> new TestTemplateService .MockTemplateScript .Factory (preprocessResult1 ))
202+ .then (invocation -> new TestTemplateService .MockTemplateScript .Factory (preprocessResult2 ));
203+
204+ ConnectorAction predictAction = ConnectorAction .builder ()
205+ .actionType (ConnectorAction .ActionType .PREDICT )
206+ .method ("POST" )
207+ .url ("http://test.com/mock" )
208+ .preProcessFunction (MLPreProcessFunction .TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT )
209+ .postProcessFunction (MLPostProcessFunction .OPENAI_EMBEDDING )
210+ .requestBody ("{\" input\" : ${parameters.input}}" )
211+ .build ();
212+ Map <String , String > parameters = ImmutableMap .of ("input_docs_processed_step_size" , "2" );
213+ HttpConnector connector = HttpConnector .builder ().name ("test connector" ).version ("1" ).protocol ("http" ).parameters (parameters ).actions (Arrays .asList (predictAction )).build ();
214+ HttpJsonConnectorExecutor executor = spy (new HttpJsonConnectorExecutor (connector ));
215+ executor .setScriptService (scriptService );
216+ when (httpClient .execute (any ())).thenReturn (response );
217+ // model takes 2 input docs, but only output 1 embedding
218+ String modelResponse = "{\n " + " \" object\" : \" list\" ,\n " + " \" data\" : [\n " + " {\n "
219+ + " \" object\" : \" embedding\" ,\n " + " \" index\" : 0,\n " + " \" embedding\" : [\n "
220+ + " -0.014555434,\n " + " -0.002135904,\n " + " 0.0035105038\n " + " ]\n "
221+ + " } ],\n "
222+ + " \" model\" : \" text-embedding-ada-002-v2\" ,\n " + " \" usage\" : {\n " + " \" prompt_tokens\" : 5,\n "
223+ + " \" total_tokens\" : 5\n " + " }\n " + "}" ;
224+ StatusLine statusLine = new BasicStatusLine (new ProtocolVersion ("HTTP" , 1 , 1 ), 200 , "OK" );
225+ when (response .getStatusLine ()).thenReturn (statusLine );
226+ HttpEntity entity = new StringEntity (modelResponse );
227+ when (response .getEntity ()).thenReturn (entity );
228+ when (executor .getHttpClient ()).thenReturn (httpClient );
229+ when (executor .getConnector ()).thenReturn (connector );
230+ MLInputDataset inputDataSet = TextDocsInputDataSet .builder ().docs (Arrays .asList ("test doc1" , "test doc2" )).build ();
231+ ModelTensorOutput modelTensorOutput = executor .executePredict (MLInput .builder ().algorithm (FunctionName .REMOTE ).inputDataset (inputDataSet ).build ());
232+ Assert .assertEquals (1 , modelTensorOutput .getMlModelOutputs ().size ());
233+ Assert .assertEquals (1 , modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ().size ());
234+ Assert .assertEquals ("sentence_embedding" , modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ().get (0 ).getName ());
235+ Assert .assertArrayEquals (new Number [] {-0.014555434 , -0.002135904 , 0.0035105038 }, modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ().get (0 ).getData ());
236+ }
193237}
0 commit comments