66import static org .mockito .Mockito .mock ;
77import static org .mockito .Mockito .when ;
88
9+ import com .fasterxml .jackson .databind .ObjectMapper ;
910import com .sap .ai .sdk .foundationmodels .openai .OpenAiClient ;
1011import com .sap .ai .sdk .foundationmodels .openai .generated .model .EmbeddingsCreate200Response ;
11- import com .sap .ai .sdk .foundationmodels .openai .generated .model .EmbeddingsCreate200ResponseDataInner ;
12- import com .sap .ai .sdk .foundationmodels .openai .generated .model .EmbeddingsCreate200ResponseUsage ;
1312import com .sap .ai .sdk .foundationmodels .openai .generated .model .EmbeddingsCreateRequest ;
1413import com .sap .ai .sdk .foundationmodels .openai .generated .model .EmbeddingsCreateRequestInput ;
1514import java .util .List ;
1615import java .util .function .Consumer ;
16+ import lombok .SneakyThrows ;
1717import lombok .val ;
1818import org .junit .jupiter .api .BeforeEach ;
1919import org .junit .jupiter .api .DisplayName ;
2020import org .junit .jupiter .api .Test ;
21- import org .springframework .ai .chat .metadata .DefaultUsage ;
2221import org .springframework .ai .document .Document ;
23- import org .springframework .ai .embedding .Embedding ;
2422import org .springframework .ai .embedding .EmbeddingOptionsBuilder ;
2523import org .springframework .ai .embedding .EmbeddingRequest ;
26- import org .springframework .ai .embedding .EmbeddingResponse ;
27- import org .springframework .ai .embedding .EmbeddingResponseMetadata ;
2824
2925class EmbeddingModelTest {
3026 private OpenAiClient client ;
@@ -34,18 +30,19 @@ void setUp() {
3430 client = mock (OpenAiClient .class );
3531 }
3632
33+ @ SneakyThrows
3734 @ Test
3835 @ DisplayName ("Call with embedding request containing valid options" )
3936 void testCallWithValidEmbeddingRequest () {
4037 val texts = List .of ("Some text" );
4138 val springAiRequest =
4239 new EmbeddingRequest (texts , EmbeddingOptionsBuilder .builder ().withDimensions (128 ).build ());
4340
44- val vector = new float [] {0.0f };
4541 val expectedOpenAiResponse =
46- new EmbeddingsCreate200Response ()
47- .data (List .of (new EmbeddingsCreate200ResponseDataInner ().embedding (vector )))
48- .usage (new EmbeddingsCreate200ResponseUsage ().promptTokens (0 ).totalTokens (0 ));
42+ new ObjectMapper ()
43+ .readValue (
44+ getClass ().getClassLoader ().getResource ("__files/embeddingResponse.json" ),
45+ EmbeddingsCreate200Response .class );
4946
5047 val expectedOpenAiRequest =
5148 new EmbeddingsCreateRequest ()
@@ -57,15 +54,11 @@ void testCallWithValidEmbeddingRequest() {
5754
5855 val actualSpringAiResponse = new OpenAiSpringEmbeddingModel (client ).call (springAiRequest );
5956
60- val modelName = "" ; // defined by client object and options not honoured
61- val expectedSpringAiResponse =
62- new EmbeddingResponse (
63- List .of (new Embedding (vector , 0 )),
64- new EmbeddingResponseMetadata (modelName , new DefaultUsage (0 , null , 0 )));
65-
66- assertThat (expectedSpringAiResponse )
67- .usingRecursiveComparison ()
68- .isEqualTo (actualSpringAiResponse );
57+ assertThat (actualSpringAiResponse ).isNotNull ();
58+ assertThat (actualSpringAiResponse .getResult ().getOutput ())
59+ .isEqualTo (new float [] {0.0f , 3.4028235E38f , 1.4E-45f , 1.23f , -4.56f });
60+ assertThat (actualSpringAiResponse .getMetadata ().getUsage ().getPromptTokens ()).isEqualTo (2 );
61+ assertThat (actualSpringAiResponse .getMetadata ().getUsage ().getTotalTokens ()).isEqualTo (2 );
6962 }
7063
7164 @ Test
@@ -83,27 +76,28 @@ void testCallWithModelOptionSetThrows() {
8376 "Do not set a model in EmbeddingOptions, as the OpenAiClient already defines the model." );
8477 }
8578
79+ @ SneakyThrows
8680 @ Test
8781 @ DisplayName ("Embed document with text content" )
8882 void testEmbedDocument () {
8983 Document document = new Document ("Some content" );
9084
91- val vector = new float [] { 1 , 2 , 3 };
92- val openAiResponse =
93- new EmbeddingsCreate200Response ()
94- . data ( List . of ( new EmbeddingsCreate200ResponseDataInner (). embedding ( vector )))
95- . usage ( new EmbeddingsCreate200ResponseUsage (). promptTokens ( 0 ). totalTokens ( 0 ) );
85+ val expectedOpenAiResponse =
86+ new ObjectMapper ()
87+ . readValue (
88+ getClass (). getClassLoader (). getResource ( "__files/embeddingResponse.json" ),
89+ EmbeddingsCreate200Response . class );
9690
9791 val expectedOpenAiRequest =
9892 new EmbeddingsCreateRequest ()
9993 .input (EmbeddingsCreateRequestInput .create (List .of (document .getFormattedContent ())));
10094
10195 when (client .embedding (assertArg (assertRecursiveEquals (expectedOpenAiRequest ))))
102- .thenReturn (openAiResponse );
96+ .thenReturn (expectedOpenAiResponse );
10397
10498 float [] result = new OpenAiSpringEmbeddingModel (client ).embed (document );
10599
106- assertThat (result ).isEqualTo (new float [] {1 , 2 , 3 });
100+ assertThat (result ).isEqualTo (new float [] {0.0f , 3.4028235E38f , 1.4E-45f , 1.23f , - 4.56f });
107101 }
108102
109103 private static <T > Consumer <T > assertRecursiveEquals (T expected ) {
0 commit comments