Skip to content

Commit d6b31c3

Browse files
authored
Fixes #4156: Improves handling of empty or blank input for openai procedures (#4228)
* Fixes #4156: Improves handling of empty or blank input for openai procedures * fix tests * changed boolean conditions
1 parent 4e2f977 commit d6b31c3

File tree

5 files changed

+96
-27
lines changed

5 files changed

+96
-27
lines changed

docs/asciidoc/modules/ROOT/pages/ml/openai.adoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ If present, they take precedence over the analogous APOC configs.
3636
By default, is `/embeddings`, `/completions` and `/chat/completions` for respectively the `apoc.ml.openai.embedding`, `apoc.ml.openai.completion` and `apoc.ml.openai.chat` procedures.
3737
| jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response.
3838
The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure.
39+
| failOnError | If true (default), the procedure fails in case of empty, blank or null input
3940
|===
4041

4142

extended/src/main/java/apoc/ml/MLUtil.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package apoc.ml;
22

33
public class MLUtil {
4-
public static final String ERROR_NULL_INPUT = "The input provided is null. Please specify a valid input";
4+
public static final String ERROR_NULL_INPUT = "Null, blank or empty input provided. Please specify a valid input";
55

66
public static final String ENDPOINT_CONF_KEY = "endpoint";
77
public static final String API_VERSION_CONF_KEY = "apiVersion";

extended/src/main/java/apoc/ml/OpenAI.java

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@
44
import apoc.Extended;
55
import apoc.result.MapResult;
66
import apoc.util.JsonUtil;
7+
import apoc.util.Util;
78
import com.fasterxml.jackson.core.JsonProcessingException;
9+
import org.apache.commons.collections.MapUtils;
10+
import org.apache.commons.lang3.StringUtils;
811
import org.neo4j.graphdb.security.URLAccessChecker;
912
import org.neo4j.procedure.Context;
1013
import org.neo4j.procedure.Description;
1114
import org.neo4j.procedure.Name;
1215
import org.neo4j.procedure.Procedure;
1316

1417
import java.net.MalformedURLException;
15-
import java.util.HashMap;
16-
import java.util.List;
17-
import java.util.Locale;
18-
import java.util.Map;
19-
import java.util.Objects;
18+
import java.util.*;
2019
import java.util.function.BiFunction;
2120
import java.util.function.Function;
21+
import java.util.function.Supplier;
2222
import java.util.stream.Collectors;
2323
import java.util.stream.Stream;
2424

@@ -35,6 +35,7 @@ public class OpenAI {
3535
public static final String JSON_PATH_CONF_KEY = "jsonPath";
3636
public static final String PATH_CONF_KEY = "path";
3737
public static final String GPT_4O_MODEL = "gpt-4o";
38+
public static final String FAIL_ON_ERROR_CONF = "failOnError";
3839

3940
@Context
4041
public ApocConfig apocConfig;
@@ -147,6 +148,10 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
147148
"model": "text-embedding-ada-002",
148149
"usage": { "prompt_tokens": 8, "total_tokens": 8 } }
149150
*/
151+
boolean failOnError = isFailOnError(configuration);
152+
if (checkNullInput(texts, failOnError)) return Stream.empty();
153+
texts = texts.stream().filter(StringUtils::isNotBlank).toList();
154+
if (checkEmptyInput(texts, failOnError)) return Stream.empty();
150155
return getEmbeddingResult(texts, apiKey, configuration, apocConfig, urlAccessChecker,
151156
(map, text) -> {
152157
Long index = (Long) map.get("index");
@@ -156,6 +161,7 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
156161
);
157162
}
158163

164+
159165
static <T> Stream<T> getEmbeddingResult(List<String> texts, String apiKey, Map<String, Object> configuration, ApocConfig apocConfig, URLAccessChecker urlAccessChecker,
160166
BiFunction<Map, String, T> embeddingMapping, Function<String, T> nullMapping) throws JsonProcessingException, MalformedURLException {
161167
if (texts == null) {
@@ -194,19 +200,19 @@ public Stream<MapResult> completion(@Name("prompt") String prompt, @Name("api_ke
194200
"usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 }
195201
}
196202
*/
197-
if (prompt == null) {
198-
throw new RuntimeException(ERROR_NULL_INPUT);
199-
}
203+
boolean failOnError = isFailOnError(configuration);
204+
if(checkBlankInput(prompt, failOnError)) return Stream.empty();
200205
return executeRequest(apiKey, configuration, "completions", "gpt-3.5-turbo-instruct", "prompt", prompt, "$", apocConfig, urlAccessChecker)
201206
.map(v -> (Map<String,Object>)v).map(MapResult::new);
202207
}
203208

204209
@Procedure("apoc.ml.openai.chat")
205210
@Description("apoc.ml.openai.chat(messages, api_key, configuration]) - prompts the completion API")
206211
public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Object>> messages, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
207-
if (messages == null) {
208-
throw new RuntimeException(ERROR_NULL_INPUT);
209-
}
212+
boolean failOnError = isFailOnError(configuration);
213+
if (checkNullInput(messages, failOnError)) return Stream.empty();
214+
messages = messages.stream().filter(MapUtils::isNotEmpty).toList();
215+
if (checkEmptyInput(messages, failOnError)) return Stream.empty();
210216
configuration.putIfAbsent("model", GPT_4O_MODEL);
211217
return executeRequest(apiKey, configuration, "chat/completions", (String) configuration.get("model"), "messages", messages, "$", apocConfig, urlAccessChecker)
212218
.map(v -> (Map<String,Object>)v).map(MapResult::new);
@@ -220,4 +226,32 @@ public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Objec
220226
} ] }
221227
*/
222228
}
229+
230+
private static boolean isFailOnError(Map<String, Object> configuration) {
231+
return Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true));
232+
}
233+
234+
static boolean checkNullInput(Object input, boolean failOnError) {
235+
return checkInput(failOnError, () -> Objects.isNull(input));
236+
}
237+
238+
static boolean checkEmptyInput(Collection<?> input, boolean failOnError) {
239+
return checkInput(failOnError, () -> input.isEmpty());
240+
}
241+
242+
static boolean checkBlankInput(String input, boolean failOnError) {
243+
return checkInput(failOnError, () -> StringUtils.isBlank(input));
244+
}
245+
246+
private static boolean checkInput(
247+
boolean failOnError,
248+
Supplier<Boolean> checkFunction
249+
){
250+
if (checkFunction.get()) {
251+
if(failOnError) throw new RuntimeException(ERROR_NULL_INPUT);
252+
return true;
253+
}
254+
return false;
255+
}
256+
223257
}
Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,15 @@
11
package apoc.ml;
22

3+
import apoc.util.ExtendedTestUtil;
34
import org.neo4j.graphdb.GraphDatabaseService;
45

56
import java.util.Map;
67

7-
import static apoc.ml.MLUtil.ERROR_NULL_INPUT;
8-
import static apoc.util.TestUtil.testCall;
9-
import static org.junit.Assert.assertTrue;
10-
import static org.junit.Assert.fail;
8+
import static apoc.ml.MLUtil.*;
119

1210
public class MLTestUtil {
11+
1312
public static void assertNullInputFails(GraphDatabaseService db, String query, Map<String, Object> params) {
14-
try {
15-
testCall(db, query, params,
16-
(row) -> fail("Should fail due to null input")
17-
);
18-
} catch (RuntimeException e) {
19-
String message = e.getMessage();
20-
assertTrue("Current error message is: " + message,
21-
message.contains(ERROR_NULL_INPUT)
22-
);
23-
}
13+
ExtendedTestUtil.assertFails(db, query, params, ERROR_NULL_INPUT);
2414
}
2515
}

extended/src/test/java/apoc/ml/OpenAIIT.java

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import static apoc.ml.MLTestUtil.assertNullInputFails;
1717
import static apoc.ml.MLUtil.MODEL_CONF_KEY;
1818
import static apoc.ml.OpenAI.GPT_4O_MODEL;
19+
import static apoc.ml.OpenAI.FAIL_ON_ERROR_CONF;
1920
import static apoc.ml.OpenAITestResultUtils.*;
2021
import static apoc.util.TestUtil.testCall;
2122
import static apoc.util.TestUtil.testResult;
@@ -140,13 +141,49 @@ public void embeddingsNull() {
140141
);
141142
}
142143

144+
@Test
145+
public void chatNull() {
146+
assertNullInputFails(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)",
147+
Map.of("apiKey", openaiKey, "conf", emptyMap())
148+
);
149+
}
150+
151+
@Test
152+
public void chatReturnsEmptyIfFailOnErrorFalse() {
153+
TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)",
154+
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
155+
);
156+
}
157+
158+
@Test
159+
public void embeddingsReturnsEmptyIfFailOnErrorFalse() {
160+
TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.embedding(null, $apiKey, $conf)",
161+
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
162+
);
163+
}
164+
165+
166+
@Test
167+
public void chatWithEmptyFails() {
168+
assertNullInputFails(db, "CALL apoc.ml.openai.chat([], $apiKey, $conf)",
169+
Map.of("apiKey", openaiKey, "conf", emptyMap())
170+
);
171+
}
172+
173+
@Test
174+
public void embeddingsWithEmptyReturnsEmptyIfFailOnErrorFalse() {
175+
TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.embedding([], $apiKey, $conf)",
176+
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
177+
);
178+
}
179+
143180
@Test
144181
public void completionNull() {
145182
assertNullInputFails(db, "CALL apoc.ml.openai.completion(null, $apiKey, $conf)",
146183
Map.of("apiKey", openaiKey, "conf", emptyMap())
147184
);
148185
}
149-
186+
150187
@Test
151188
public void chatCompletionNull() {
152189
assertNullInputFails(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)",
@@ -160,4 +197,11 @@ public void chatCompletionNullGpt35Turbo() {
160197
Map.of("apiKey", openaiKey, "conf", Map.of(MODEL_CONF_KEY, GPT_35_MODEL))
161198
);
162199
}
200+
201+
@Test
202+
public void completionReturnsEmptyIfFailOnErrorFalse() {
203+
TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.completion(null, $apiKey, $conf)",
204+
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
205+
);
206+
}
163207
}

0 commit comments

Comments
 (0)