@@ -234,4 +234,45 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOE
234234 Assert .assertEquals ("sentence_embedding" , modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ().get (0 ).getName ());
235235 Assert .assertArrayEquals (new Number [] {-0.014555434 , -0.002135904 , 0.0035105038 }, modelTensorOutput .getMlModelOutputs ().get (0 ).getMlModelTensors ().get (0 ).getData ());
236236 }
237+
238+ @ Test
239+ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs_InvalidStepSize () throws IOException {
240+ exceptionRule .expect (IllegalArgumentException .class );
241+ exceptionRule .expectMessage ("Invalid parameter: input_docs_processed_step_size. It must be positive integer." );
242+ String preprocessResult1 = "{\" parameters\" : { \" input\" : \" test doc1\" } }" ;
243+ String preprocessResult2 = "{\" parameters\" : { \" input\" : \" test doc2\" } }" ;
244+ when (scriptService .compile (any (), any ()))
245+ .then (invocation -> new TestTemplateService .MockTemplateScript .Factory (preprocessResult1 ))
246+ .then (invocation -> new TestTemplateService .MockTemplateScript .Factory (preprocessResult2 ));
247+
248+ ConnectorAction predictAction = ConnectorAction .builder ()
249+ .actionType (ConnectorAction .ActionType .PREDICT )
250+ .method ("POST" )
251+ .url ("http://test.com/mock" )
252+ .preProcessFunction (MLPreProcessFunction .TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT )
253+ .postProcessFunction (MLPostProcessFunction .OPENAI_EMBEDDING )
254+ .requestBody ("{\" input\" : ${parameters.input}}" )
255+ .build ();
256+ // step size must be positive integer, here we set it as -1, should trigger IllegalArgumentException
257+ Map <String , String > parameters = ImmutableMap .of ("input_docs_processed_step_size" , "-1" );
258+ HttpConnector connector = HttpConnector .builder ().name ("test connector" ).version ("1" ).protocol ("http" ).parameters (parameters ).actions (Arrays .asList (predictAction )).build ();
259+ HttpJsonConnectorExecutor executor = spy (new HttpJsonConnectorExecutor (connector ));
260+ executor .setScriptService (scriptService );
261+ when (httpClient .execute (any ())).thenReturn (response );
262+ // model takes 2 input docs, but only output 1 embedding
263+ String modelResponse = "{\n " + " \" object\" : \" list\" ,\n " + " \" data\" : [\n " + " {\n "
264+ + " \" object\" : \" embedding\" ,\n " + " \" index\" : 0,\n " + " \" embedding\" : [\n "
265+ + " -0.014555434,\n " + " -0.002135904,\n " + " 0.0035105038\n " + " ]\n "
266+ + " } ],\n "
267+ + " \" model\" : \" text-embedding-ada-002-v2\" ,\n " + " \" usage\" : {\n " + " \" prompt_tokens\" : 5,\n "
268+ + " \" total_tokens\" : 5\n " + " }\n " + "}" ;
269+ StatusLine statusLine = new BasicStatusLine (new ProtocolVersion ("HTTP" , 1 , 1 ), 200 , "OK" );
270+ when (response .getStatusLine ()).thenReturn (statusLine );
271+ HttpEntity entity = new StringEntity (modelResponse );
272+ when (response .getEntity ()).thenReturn (entity );
273+ when (executor .getHttpClient ()).thenReturn (httpClient );
274+ when (executor .getConnector ()).thenReturn (connector );
275+ MLInputDataset inputDataSet = TextDocsInputDataSet .builder ().docs (Arrays .asList ("test doc1" , "test doc2" )).build ();
276+ ModelTensorOutput modelTensorOutput = executor .executePredict (MLInput .builder ().algorithm (FunctionName .REMOTE ).inputDataset (inputDataSet ).build ());
277+ }
237278}
0 commit comments