11package apoc .ml ;
22
33import apoc .util .TestUtil ;
4+ import apoc .util .Util ;
45import org .junit .Assume ;
56import org .junit .BeforeClass ;
67import org .junit .ClassRule ;
78import org .junit .Test ;
9+ import org .junit .runner .RunWith ;
10+ import org .junit .runners .Parameterized ;
811import org .neo4j .test .rule .DbmsRule ;
912import org .neo4j .test .rule .ImpermanentDbmsRule ;
1013
11- import java .util .List ;
12- import java .util .Map ;
13- import java .util .Set ;
14+ import java .util .*;
1415
1516import static apoc .ml .MLUtil .MODEL_CONF_KEY ;
1617import static apoc .ml .MixedbreadAI .*;
2324import static org .junit .jupiter .api .Assertions .assertEquals ;
2425import static org .junit .jupiter .api .Assertions .fail ;
2526
26- public class MixedbreadAIIT {
27+ @ RunWith (Parameterized .class )
28+ public class MixedbreadAiIT {
2729
2830 @ ClassRule
2931 public static DbmsRule db = new ImpermanentDbmsRule ();
@@ -40,6 +42,25 @@ public static void setUp() throws Exception {
4042 TestUtil .registerProcedure (db , MixedbreadAI .class );
4143 }
4244
45+ @ Parameterized .Parameters (name = "chatModel: {0}" )
46+ public static Collection <String []> data () {
47+ return Arrays .asList (new String [][] {
48+ // tests with model evaluated
49+ {"mxbai-embed-2d-large-v1" },
50+ {"mixedbread-ai/mxbai-rerank-large-v1" },
51+ {"mixedbread-ai/mxbai-rerank-large-v2" },
52+ // tests with default model
53+ {null }
54+ });
55+ }
56+
57+ @ Parameterized .Parameter (0 )
58+ public String chatModel ;
59+
60+ protected String getApiKey (){
61+ return apiKey ;
62+ }
63+
4364 @ Test
4465 public void getEmbedding () {
4566 testResult (db , "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)" ,
@@ -58,7 +79,7 @@ public void getEmbedding() {
5879 @ Test
5980 public void getEmbeddingWithNulls () {
6081 testResult (db , "CALL apoc.ml.mixedbread.embedding([null, 'Some Text', null, 'Another Text'], $apiKey, $conf)" ,
61- Map . of ("apiKey" , apiKey , "conf" , emptyMap ()),
82+ Util . map ("apiKey" , apiKey , "conf" , emptyMap ()),
6283 (r ) -> {
6384
6485 Map <String , Object > row = r .next ();
@@ -129,21 +150,6 @@ public void getEmbeddingWithCustomEmbeddingSize() {
129150 });
130151 }
131152
132- @ Test
133- public void getEmbeddingWithOtherModel () {
134- testResult (db , "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)" ,
135- map ("apiKey" , apiKey , "conf" , map (MODEL_CONF_KEY , "mxbai-embed-2d-large-v1" )),
136- r -> {
137- Map <String , Object > row = r .next ();
138- assertEmbedding (row , 0L , "Some Text" , 1024 );
139-
140- row = r .next ();
141- assertEmbedding (row , 1L , "Other Text" , 1024 );
142-
143- assertFalse (r .hasNext ());
144- });
145- }
146-
147153 @ Test
148154 public void getEmbeddingWithWrongModel () {
149155 try {
@@ -161,8 +167,41 @@ public void getEmbeddingWithWrongModel() {
161167 }
162168 }
163169
170+ @ Test
171+ public void customWithMissingEndpoint () {
172+ try {
173+ testCall (db , "CALL apoc.ml.mixedbread.custom($apiKey, $conf)" ,
174+ map ("apiKey" , apiKey ,
175+ "conf" , map (MODEL_CONF_KEY , "aModelId" )
176+ ),
177+ r -> fail ("Should fail due to missing endpoint" ));
178+ } catch (Exception e ) {
179+ String errMsg = e .getMessage ();
180+ assertTrue ("Actual error message is: " + errMsg ,
181+ errMsg .contains (ERROR_MSG_MISSING_ENDPOINT )
182+ );
183+ }
184+ }
185+
186+ @ Test
187+ public void customWithMissingModel () {
188+ try {
189+ testCall (db , "CALL apoc.ml.mixedbread.custom($apiKey, $conf)" ,
190+ map ("apiKey" , apiKey ,
191+ "conf" , map (ENDPOINT_CONF_KEY , MIXEDBREAD_BASE_URL + "/reranking" ,
192+ "foo" , "bar" )
193+ ),
194+ r -> fail ("Should fail due to missing model" ));
195+ } catch (Exception e ) {
196+ String errMsg = e .getMessage ();
197+ assertTrue ("Actual error message is: " + errMsg ,
198+ errMsg .contains (ERROR_MSG_MISSING_MODELID )
199+ );
200+ }
201+ }
202+
164203 /**
165- * Example taken from here: https://www.mixedbread.ai/api-reference/endpoints/reranking
204+ * Example taken from here: https://www.mixedbread.ai/api-reference/endpoints/reranking
166205 */
167206 @ Test
168207 public void customWithReranking () {
@@ -174,21 +213,21 @@ public void customWithReranking() {
174213 "The Great Gatsby, a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."
175214 );
176215 Map <String , Object > conf = map (ENDPOINT_CONF_KEY , MIXEDBREAD_BASE_URL + "/reranking" ,
177- MODEL_CONF_KEY , "mixedbread-ai/mxbai-rerank-large-v1" ,
216+ MODEL_CONF_KEY , chatModel ,
178217 "query" , "Who is the author of To Kill a Mockingbird?" ,
179218 "top_k" , 3 ,
180219 "input" , input
181220 );
182221 testCall (db , "CALL apoc.ml.mixedbread.custom($apiKey, $conf)" ,
183- Map . of ("apiKey" , apiKey , "conf" , conf ),
222+ Util . map ("apiKey" , getApiKey () , "conf" , conf ),
184223 row -> {
185224 Map value = (Map ) row .get ("value" );
186-
225+
187226 List <Map > data = (List <Map >) value .get ("data" );
188227 assertEquals (3 , data .size ());
189-
190- Map <String , Object > firstData = map ("index" , 0L ,
191- "score" , 0.9980469 ,
228+
229+ Map <String , Object > firstData = map ("index" , 0L ,
230+ "score" , 0.9980469 ,
192231 "object" , "text_document" );
193232 assertEquals (firstData , data .get (0 ));
194233
@@ -204,45 +243,12 @@ public void customWithReranking() {
204243 "score" , 0.06915283 ,
205244 "object" , "text_document" );
206245 assertEquals (thirdData , data .get (2 ));
207-
246+
208247 assertEquals ("list" , value .get ("object" ));
209248 });
210249 }
211250
212- @ Test
213- public void customWithMissingEndpoint () {
214- try {
215- testCall (db , "CALL apoc.ml.mixedbread.custom($apiKey, $conf)" ,
216- map ("apiKey" , apiKey ,
217- "conf" , map (MODEL_CONF_KEY , "aModelId" )
218- ),
219- r -> fail ("Should fail due to missing endpoint" ));
220- } catch (Exception e ) {
221- String errMsg = e .getMessage ();
222- assertTrue ("Actual error message is: " + errMsg ,
223- errMsg .contains (ERROR_MSG_MISSING_ENDPOINT )
224- );
225- }
226- }
227-
228- @ Test
229- public void customWithMissingModel () {
230- try {
231- testCall (db , "CALL apoc.ml.mixedbread.custom($apiKey, $conf)" ,
232- map ("apiKey" , apiKey ,
233- "conf" , map (ENDPOINT_CONF_KEY , MIXEDBREAD_BASE_URL + "/reranking" ,
234- "foo" , "bar" )
235- ),
236- r -> fail ("Should fail due to missing model" ));
237- } catch (Exception e ) {
238- String errMsg = e .getMessage ();
239- assertTrue ("Actual error message is: " + errMsg ,
240- errMsg .contains (ERROR_MSG_MISSING_MODELID )
241- );
242- }
243- }
244-
245- private static void assertEmbedding (Map <String , Object > row ,
251+ protected static void assertEmbedding (Map <String , Object > row ,
246252 long expectedIdx ,
247253 String expectedText ,
248254 Integer expectedSize ) {
@@ -252,8 +258,8 @@ private static void assertEmbedding(Map<String, Object> row,
252258 assertEquals (expectedSize , embedding .size ());
253259 }
254260
255- private static void assertNullEmbedding (Map <String , Object > row ) {
261+ protected static void assertNullEmbedding (Map <String , Object > row ) {
256262 assertEmbedding (row , -1 , null , 0 );
257263 }
258-
264+
259265}
0 commit comments