44import apoc .Extended ;
55import apoc .result .MapResult ;
66import apoc .util .JsonUtil ;
7+ import apoc .util .Util ;
78import com .fasterxml .jackson .core .JsonProcessingException ;
9+ import org .apache .commons .collections .MapUtils ;
10+ import org .apache .commons .lang3 .StringUtils ;
811import org .neo4j .graphdb .security .URLAccessChecker ;
912import org .neo4j .procedure .Context ;
1013import org .neo4j .procedure .Description ;
1114import org .neo4j .procedure .Name ;
1215import org .neo4j .procedure .Procedure ;
1316
1417import java .net .MalformedURLException ;
15- import java .util .HashMap ;
16- import java .util .List ;
17- import java .util .Locale ;
18- import java .util .Map ;
19- import java .util .Objects ;
18+ import java .util .*;
2019import java .util .function .BiFunction ;
2120import java .util .function .Function ;
21+ import java .util .function .Supplier ;
2222import java .util .stream .Collectors ;
2323import java .util .stream .Stream ;
2424
@@ -35,6 +35,7 @@ public class OpenAI {
3535 public static final String JSON_PATH_CONF_KEY = "jsonPath" ;
3636 public static final String PATH_CONF_KEY = "path" ;
3737 public static final String GPT_4O_MODEL = "gpt-4o" ;
38+ public static final String FAIL_ON_ERROR_CONF = "failOnError" ;
3839
3940 @ Context
4041 public ApocConfig apocConfig ;
@@ -147,6 +148,10 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
147148 "model": "text-embedding-ada-002",
148149 "usage": { "prompt_tokens": 8, "total_tokens": 8 } }
149150 */
151+ boolean failOnError = isFailOnError (configuration );
152+ if (checkNullInput (texts , failOnError )) return Stream .empty ();
153+ texts = texts .stream ().filter (StringUtils ::isNotBlank ).toList ();
154+ if (checkEmptyInput (texts , failOnError )) return Stream .empty ();
150155 return getEmbeddingResult (texts , apiKey , configuration , apocConfig , urlAccessChecker ,
151156 (map , text ) -> {
152157 Long index = (Long ) map .get ("index" );
@@ -156,6 +161,7 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
156161 );
157162 }
158163
164+
159165 static <T > Stream <T > getEmbeddingResult (List <String > texts , String apiKey , Map <String , Object > configuration , ApocConfig apocConfig , URLAccessChecker urlAccessChecker ,
160166 BiFunction <Map , String , T > embeddingMapping , Function <String , T > nullMapping ) throws JsonProcessingException , MalformedURLException {
161167 if (texts == null ) {
@@ -194,19 +200,19 @@ public Stream<MapResult> completion(@Name("prompt") String prompt, @Name("api_ke
194200 "usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 }
195201 }
196202 */
197- if (prompt == null ) {
198- throw new RuntimeException (ERROR_NULL_INPUT );
199- }
203+ boolean failOnError = isFailOnError (configuration );
204+ if (checkBlankInput (prompt , failOnError )) return Stream .empty ();
200205 return executeRequest (apiKey , configuration , "completions" , "gpt-3.5-turbo-instruct" , "prompt" , prompt , "$" , apocConfig , urlAccessChecker )
201206 .map (v -> (Map <String ,Object >)v ).map (MapResult ::new );
202207 }
203208
204209 @ Procedure ("apoc.ml.openai.chat" )
205210 @ Description ("apoc.ml.openai.chat(messages, api_key, configuration]) - prompts the completion API" )
206211 public Stream <MapResult > chatCompletion (@ Name ("messages" ) List <Map <String , Object >> messages , @ Name ("api_key" ) String apiKey , @ Name (value = "configuration" , defaultValue = "{}" ) Map <String , Object > configuration ) throws Exception {
207- if (messages == null ) {
208- throw new RuntimeException (ERROR_NULL_INPUT );
209- }
212+ boolean failOnError = isFailOnError (configuration );
213+ if (checkNullInput (messages , failOnError )) return Stream .empty ();
214+ messages = messages .stream ().filter (MapUtils ::isNotEmpty ).toList ();
215+ if (checkEmptyInput (messages , failOnError )) return Stream .empty ();
210216 configuration .putIfAbsent ("model" , GPT_4O_MODEL );
211217 return executeRequest (apiKey , configuration , "chat/completions" , (String ) configuration .get ("model" ), "messages" , messages , "$" , apocConfig , urlAccessChecker )
212218 .map (v -> (Map <String ,Object >)v ).map (MapResult ::new );
@@ -220,4 +226,32 @@ public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Objec
220226 } ] }
221227 */
222228 }
229+
230+ private static boolean isFailOnError (Map <String , Object > configuration ) {
231+ return Util .toBoolean (configuration .getOrDefault (FAIL_ON_ERROR_CONF , true ));
232+ }
233+
234+ static boolean checkNullInput (Object input , boolean failOnError ) {
235+ return checkInput (failOnError , () -> Objects .isNull (input ));
236+ }
237+
238+ static boolean checkEmptyInput (Collection <?> input , boolean failOnError ) {
239+ return checkInput (failOnError , () -> input .isEmpty ());
240+ }
241+
242+ static boolean checkBlankInput (String input , boolean failOnError ) {
243+ return checkInput (failOnError , () -> StringUtils .isBlank (input ));
244+ }
245+
246+ private static boolean checkInput (
247+ boolean failOnError ,
248+ Supplier <Boolean > checkFunction
249+ ){
250+ if (checkFunction .get ()) {
251+ if (failOnError ) throw new RuntimeException (ERROR_NULL_INPUT );
252+ return true ;
253+ }
254+ return false ;
255+ }
256+
223257}
0 commit comments