Skip to content

Commit b58121f

Browse files
authored
Fixes #4408: Update LLM/Embedding models to the recent ones (#4426)
* update llm/embedding models * fix tests * update models * updated tests * fix test param
1 parent bdabadb commit b58121f

24 files changed

+494
-310
lines changed

extended/src/main/java/apoc/ml/OpenAI.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class OpenAI {
3535
public static final String APIKEY_CONF_KEY = "apiKey";
3636
public static final String JSON_PATH_CONF_KEY = "jsonPath";
3737
public static final String PATH_CONF_KEY = "path";
38-
public static final String GPT_4O_MODEL = "gpt-4o";
38+
public static final String GPT_DEFAULT_CHAT_MODEL = "gpt-4.1";
3939
public static final String FAIL_ON_ERROR_CONF = "failOnError";
4040
public static final String ENABLE_BACK_OFF_RETRIES_CONF_KEY = "enableBackOffRetries";
4141
public static final String ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY = "exponentialBackoff";
@@ -73,7 +73,7 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
7373
apocConfig.getString(APOC_ML_OPENAI_TYPE, OpenAIRequestHandler.Type.OPENAI.name())
7474
);
7575
OpenAIRequestHandler.Type type = OpenAIRequestHandler.Type.valueOf(apiTypeString.toUpperCase(Locale.ENGLISH));
76-
76+
7777
var configForPayload = new HashMap<>(configuration);
7878
// we remove these keys from configForPayload, since the JSON payload is built from the configForPayload map
7979
Stream.of(ENDPOINT_CONF_KEY, API_TYPE_CONF_KEY, API_VERSION_CONF_KEY, APIKEY_CONF_KEY).forEach(configForPayload::remove);
@@ -90,7 +90,7 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
9090
apiType.addApiKey(headers, apiKey);
9191

9292
String payload = JsonUtil.OBJECT_MAPPER.writeValueAsString(configForPayload);
93-
93+
9494
// new URL(new URL(endpoint), path) can produce a wrong path, since the endpoint can already contain path segments (for example an embeddings deployment),
9595
// eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model
9696
// therefore it is better to join the non-empty path pieces
@@ -133,7 +133,7 @@ private static void handleAPIProvider(OpenAIRequestHandler.Type type,
133133
} else {
134134
configuration.putIfAbsent(PATH_CONF_KEY, "messages");
135135
configForPayload.putIfAbsent(MAX_TOKENS, 1000);
136-
configForPayload.putIfAbsent(MODEL_CONF_KEY, "claude-3-5-sonnet-20240620");
136+
configForPayload.putIfAbsent(MODEL_CONF_KEY, "claude-3-7-sonnet-latest");
137137
}
138138

139139
configForPayload.remove(ANTHROPIC_VERSION);
@@ -170,7 +170,7 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
170170
(map, text) -> {
171171
Long index = (Long) map.get("index");
172172
return new EmbeddingResult(index, text, (List<Double>) map.get("embedding"));
173-
},
173+
},
174174
m -> new EmbeddingResult(-1, m, List.of())
175175
);
176176
}
@@ -181,7 +181,7 @@ static <T> Stream<T> getEmbeddingResult(List<String> texts, String apiKey, Map<S
181181
if (texts == null) {
182182
throw new RuntimeException(ERROR_NULL_INPUT);
183183
}
184-
184+
185185
Map<Boolean, List<String>> collect = texts.stream()
186186
.collect(Collectors.groupingBy(Objects::nonNull));
187187

@@ -227,7 +227,7 @@ public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Objec
227227
if (checkNullInput(messages, failOnError)) return Stream.empty();
228228
messages = messages.stream().filter(ExtendedMapUtils::isNotEmpty).toList();
229229
if (checkEmptyInput(messages, failOnError)) return Stream.empty();
230-
return executeRequest(apiKey, configuration, "chat/completions", GPT_4O_MODEL, "messages", messages, "$", apocConfig, urlAccessChecker)
230+
return executeRequest(apiKey, configuration, "chat/completions", GPT_DEFAULT_CHAT_MODEL, "messages", messages, "$", apocConfig, urlAccessChecker)
231231
.map(v -> (Map<String,Object>)v).map(MapResult::new);
232232
// https://platform.openai.com/docs/api-reference/chat/create
233233
/*

extended/src/main/java/apoc/ml/VertexAI.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
109109
List<String> nonNullTexts = collect.get(true);
110110

111111
Object inputs = texts.stream().map(text -> Map.of("content", text)).toList();
112-
Stream<Object> resultStream = executeRequest(accessToken, project, configuration, "textembedding-gecko", inputs, List.of(), urlAccessChecker);
112+
Stream<Object> resultStream = executeRequest(accessToken, project, configuration, "text-embedding-004", inputs, List.of(), urlAccessChecker);
113113
AtomicInteger ai = new AtomicInteger();
114114
Stream<EmbeddingResult> embeddingResultStream = resultStream
115115
.flatMap(v -> ((List<Map<String, Object>>) v).stream())
@@ -223,7 +223,7 @@ public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Strin
223223
}
224224
Object inputs = List.of(Map.of("context",context, "examples",examples, "messages", messages));
225225
var parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens");
226-
return executeRequest(accessToken, project, configuration, "chat-bison", inputs, parameterKeys, urlAccessChecker)
226+
return executeRequest(accessToken, project, configuration, "gemini-2.5-pro", inputs, parameterKeys, urlAccessChecker)
227227
.flatMap(v -> ((List<Map<String, Object>>) v).stream())
228228
.map(v -> (Map<String, Object>) v).map(MapResult::new);
229229
// POST https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID:predict
@@ -311,7 +311,7 @@ public Stream<MapResult> stream(@Name("messages") List<Map<String, String>> cont
311311
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
312312
var parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens");
313313

314-
return executeRequest(accessToken, project, configuration, "gemini-pro", contents, parameterKeys, urlAccessChecker, VertexAIHandler.Type.STREAM)
314+
return executeRequest(accessToken, project, configuration, "gemini-2.5-pro", contents, parameterKeys, urlAccessChecker, VertexAIHandler.Type.STREAM)
315315
.flatMap(v -> ((List<Map<String, Object>>) v).stream())
316316
.map(MapResult::new);
317317
}
@@ -322,7 +322,7 @@ public Stream<ObjectResult> custom(@Name(value = "body") Map<String, Object> bod
322322
@Name("accessToken") String accessToken,
323323
@Name("project") String project,
324324
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
325-
return executeRequest(accessToken, project, configuration, "gemini-pro", body, Collections.emptyList(), urlAccessChecker, VertexAIHandler.Type.CUSTOM)
325+
return executeRequest(accessToken, project, configuration, "gemini-2.5-pro", body, Collections.emptyList(), urlAccessChecker, VertexAIHandler.Type.CUSTOM)
326326
.map(ObjectResult::new);
327327
}
328328
}

extended/src/main/java/apoc/ml/aws/Bedrock.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public class Bedrock {
3131
// public for testing purposes
3232
public static final String JURASSIC_2_ULTRA = "ai21.j2-ultra-v1";
3333
public static final String TITAN_EMBED_TEXT = "amazon.titan-embed-text-v1";
34-
public static final String ANTHROPIC_CLAUDE_V2 = "anthropic.claude-v2";
34+
public static final String ANTHROPIC_CLAUDE_V2 = "anthropic.claude-3";
3535
public static final String STABILITY_STABLE_DIFFUSION_XL = "stability.stable-diffusion-xl-v0";
3636

3737
@Procedure("apoc.ml.bedrock.list")

extended/src/test/java/apoc/ml/MixedbreadAIIT.java renamed to extended/src/test/java/apoc/ml/MixedbreadAiIT.java

Lines changed: 70 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
package apoc.ml;
22

33
import apoc.util.TestUtil;
4+
import apoc.util.Util;
45
import org.junit.Assume;
56
import org.junit.BeforeClass;
67
import org.junit.ClassRule;
78
import org.junit.Test;
9+
import org.junit.runner.RunWith;
10+
import org.junit.runners.Parameterized;
811
import org.neo4j.test.rule.DbmsRule;
912
import org.neo4j.test.rule.ImpermanentDbmsRule;
1013

11-
import java.util.List;
12-
import java.util.Map;
13-
import java.util.Set;
14+
import java.util.*;
1415

1516
import static apoc.ml.MLUtil.MODEL_CONF_KEY;
1617
import static apoc.ml.MixedbreadAI.*;
@@ -23,7 +24,8 @@
2324
import static org.junit.jupiter.api.Assertions.assertEquals;
2425
import static org.junit.jupiter.api.Assertions.fail;
2526

26-
public class MixedbreadAIIT {
27+
@RunWith(Parameterized.class)
28+
public class MixedbreadAiIT {
2729

2830
@ClassRule
2931
public static DbmsRule db = new ImpermanentDbmsRule();
@@ -40,6 +42,25 @@ public static void setUp() throws Exception {
4042
TestUtil.registerProcedure(db, MixedbreadAI.class);
4143
}
4244

45+
@Parameterized.Parameters(name = "chatModel: {0}")
46+
public static Collection<String[]> data() {
47+
return Arrays.asList(new String[][] {
48+
// tests with an explicitly specified model
49+
{"mxbai-embed-2d-large-v1"},
50+
{"mixedbread-ai/mxbai-rerank-large-v1"},
51+
{"mixedbread-ai/mxbai-rerank-large-v2"},
52+
// tests with default model
53+
{null}
54+
});
55+
}
56+
57+
@Parameterized.Parameter(0)
58+
public String chatModel;
59+
60+
protected String getApiKey(){
61+
return apiKey;
62+
}
63+
4364
@Test
4465
public void getEmbedding() {
4566
testResult(db, "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)",
@@ -58,7 +79,7 @@ public void getEmbedding() {
5879
@Test
5980
public void getEmbeddingWithNulls() {
6081
testResult(db, "CALL apoc.ml.mixedbread.embedding([null, 'Some Text', null, 'Another Text'], $apiKey, $conf)",
61-
Map.of("apiKey", apiKey, "conf", emptyMap()),
82+
Util.map("apiKey", apiKey, "conf", emptyMap()),
6283
(r) -> {
6384

6485
Map<String, Object> row = r.next();
@@ -129,21 +150,6 @@ public void getEmbeddingWithCustomEmbeddingSize() {
129150
});
130151
}
131152

132-
@Test
133-
public void getEmbeddingWithOtherModel() {
134-
testResult(db, "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)",
135-
map("apiKey", apiKey, "conf", map(MODEL_CONF_KEY, "mxbai-embed-2d-large-v1")),
136-
r -> {
137-
Map<String, Object> row = r.next();
138-
assertEmbedding(row, 0L, "Some Text", 1024);
139-
140-
row = r.next();
141-
assertEmbedding(row, 1L, "Other Text", 1024);
142-
143-
assertFalse(r.hasNext());
144-
});
145-
}
146-
147153
@Test
148154
public void getEmbeddingWithWrongModel() {
149155
try {
@@ -161,8 +167,41 @@ public void getEmbeddingWithWrongModel() {
161167
}
162168
}
163169

170+
@Test
171+
public void customWithMissingEndpoint() {
172+
try {
173+
testCall(db, "CALL apoc.ml.mixedbread.custom($apiKey, $conf)",
174+
map("apiKey", apiKey,
175+
"conf", map(MODEL_CONF_KEY, "aModelId")
176+
),
177+
r -> fail("Should fail due to missing endpoint"));
178+
} catch (Exception e) {
179+
String errMsg = e.getMessage();
180+
assertTrue("Actual error message is: " + errMsg,
181+
errMsg.contains(ERROR_MSG_MISSING_ENDPOINT)
182+
);
183+
}
184+
}
185+
186+
@Test
187+
public void customWithMissingModel() {
188+
try {
189+
testCall(db, "CALL apoc.ml.mixedbread.custom($apiKey, $conf)",
190+
map("apiKey", apiKey,
191+
"conf", map(ENDPOINT_CONF_KEY, MIXEDBREAD_BASE_URL + "/reranking",
192+
"foo", "bar")
193+
),
194+
r -> fail("Should fail due to missing model"));
195+
} catch (Exception e) {
196+
String errMsg = e.getMessage();
197+
assertTrue("Actual error message is: " + errMsg,
198+
errMsg.contains(ERROR_MSG_MISSING_MODELID)
199+
);
200+
}
201+
}
202+
164203
/**
165-
* Example taken from here: https://www.mixedbread.ai/api-reference/endpoints/reranking
204+
* Example taken from here: https://www.mixedbread.ai/api-reference/endpoints/reranking
166205
*/
167206
@Test
168207
public void customWithReranking() {
@@ -174,21 +213,21 @@ public void customWithReranking() {
174213
"The Great Gatsby, a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."
175214
);
176215
Map<String, Object> conf = map(ENDPOINT_CONF_KEY, MIXEDBREAD_BASE_URL + "/reranking",
177-
MODEL_CONF_KEY, "mixedbread-ai/mxbai-rerank-large-v1",
216+
MODEL_CONF_KEY, chatModel,
178217
"query", "Who is the author of To Kill a Mockingbird?",
179218
"top_k", 3,
180219
"input", input
181220
);
182221
testCall(db, "CALL apoc.ml.mixedbread.custom($apiKey, $conf)",
183-
Map.of("apiKey", apiKey, "conf", conf),
222+
Util.map("apiKey", getApiKey(), "conf", conf),
184223
row -> {
185224
Map value = (Map) row.get("value");
186-
225+
187226
List<Map> data = (List<Map>) value.get("data");
188227
assertEquals(3, data.size());
189-
190-
Map<String, Object> firstData = map("index", 0L,
191-
"score", 0.9980469,
228+
229+
Map<String, Object> firstData = map("index", 0L,
230+
"score", 0.9980469,
192231
"object", "text_document");
193232
assertEquals(firstData, data.get(0));
194233

@@ -204,45 +243,12 @@ public void customWithReranking() {
204243
"score", 0.06915283,
205244
"object", "text_document");
206245
assertEquals(thirdData, data.get(2));
207-
246+
208247
assertEquals("list", value.get("object"));
209248
});
210249
}
211250

212-
@Test
213-
public void customWithMissingEndpoint() {
214-
try {
215-
testCall(db, "CALL apoc.ml.mixedbread.custom($apiKey, $conf)",
216-
map("apiKey", apiKey,
217-
"conf", map(MODEL_CONF_KEY, "aModelId")
218-
),
219-
r -> fail("Should fail due to missing endpoint"));
220-
} catch (Exception e) {
221-
String errMsg = e.getMessage();
222-
assertTrue("Actual error message is: " + errMsg,
223-
errMsg.contains(ERROR_MSG_MISSING_ENDPOINT)
224-
);
225-
}
226-
}
227-
228-
@Test
229-
public void customWithMissingModel() {
230-
try {
231-
testCall(db, "CALL apoc.ml.mixedbread.custom($apiKey, $conf)",
232-
map("apiKey", apiKey,
233-
"conf", map(ENDPOINT_CONF_KEY, MIXEDBREAD_BASE_URL + "/reranking",
234-
"foo", "bar")
235-
),
236-
r -> fail("Should fail due to missing model"));
237-
} catch (Exception e) {
238-
String errMsg = e.getMessage();
239-
assertTrue("Actual error message is: " + errMsg,
240-
errMsg.contains(ERROR_MSG_MISSING_MODELID)
241-
);
242-
}
243-
}
244-
245-
private static void assertEmbedding(Map<String, Object> row,
251+
protected static void assertEmbedding(Map<String, Object> row,
246252
long expectedIdx,
247253
String expectedText,
248254
Integer expectedSize) {
@@ -252,8 +258,8 @@ private static void assertEmbedding(Map<String, Object> row,
252258
assertEquals(expectedSize, embedding.size());
253259
}
254260

255-
private static void assertNullEmbedding(Map<String, Object> row) {
261+
protected static void assertNullEmbedding(Map<String, Object> row) {
256262
assertEmbedding(row, -1, null, 0);
257263
}
258-
264+
259265
}

0 commit comments

Comments
 (0)