Skip to content

Commit 60d8363

Browse files
authored
Merge pull request #2 from Dannyj1/embedding
Implemented Support for the Embeddings Endpoint
2 parents b1264a8 + 2f00b3e commit 60d8363

20 files changed

+459
-67
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44
branches:
55
- master
66
pull_request:
7-
types: [opened, synchronize, reopened]
7+
types: [ opened, synchronize, reopened ]
88
jobs:
99
build:
1010
name: Build and analyze

README.md

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@ Currently supports all chat completion models. At the time of writing these are:
77
- mistral-tiny
88
- mistral-small
99
- mistral-medium
10+
- mistral-embed
1011

11-
The embedding endpoint will be supported at a later date.
12+
New models or models not listed here may be already supported without any updates to the library.
1213

13-
**NOTE:** This library is currently in **alpha**. It is currently NOT possible to using streaming in message completions
14-
or to use embedding models. These features will be added in the future. The currently supported APIs should be stable
14+
**NOTE:** This library is currently in **alpha**. It is currently NOT possible to using streaming in message
15+
completions. This will be added in the future. The currently supported APIs should be stable
1516
however.
1617

1718
# Supported APIs
@@ -20,11 +21,13 @@ Mistral-java-client is built against version 0.0.1 of the [Mistral AI API](https
2021

2122
- [Create Chat Completions](https://docs.mistral.ai/api/#operation/createChatCompletion)
2223
- [List Available Models](https://docs.mistral.ai/api/#operation/listModels)
23-
- "Create Embeddings" to be implemented later
24+
- [Create Embeddings](https://docs.mistral.ai/guides/embeddings/)
2425

2526
# Requirements
27+
2628
- Java 17 or higher
27-
- A Mistral AI API Key (see the [Mistral documentation](https://docs.mistral.ai/#api-access) for more details on API access)
29+
- A Mistral AI API Key (see the [Mistral documentation](https://docs.mistral.ai/#api-access) for more details on API
30+
access)
2831

2932
# Installation
3033

@@ -72,12 +75,15 @@ String apiKey = "API_KEY_HERE";
7275
MistralClient client = new MistralClient(apiKey);
7376

7477
// Get a list of available models
75-
List<Model> models = client.listModels().getModels();
78+
List<Model> models = client.listModels().getModels();
7679

7780
// Loop through all available models and print their ID. The id can be used to specify the model when creating chat completions
78-
for (Model model : models) {
79-
System.out.println(model.getId());
80-
}
81+
for(
82+
Model model :models){
83+
System.out.
84+
85+
println(model.getId());
86+
}
8187
```
8288

8389
Example output:
@@ -137,10 +143,41 @@ public class HelloWorld {
137143
'''
138144
```
139145

146+
## Embeddings
147+
148+
```java
149+
// You can also put the API key in an environment variable called MISTRAL_API_KEY and remove the apiKey parameter given to the MistralClient constructor
150+
String apiKey = "API_KEY_HERE";
151+
152+
// Initialize the client. This should ideally only be done once. The instance should be re-used for multiple requests
153+
MistralClient client = new MistralClient(apiKey);
154+
List<String> exampleTexts = List.of(
155+
"This is a test sentence.",
156+
"This is another test sentence."
157+
);
158+
159+
EmbeddingRequest embeddingRequest = EmbeddingRequest.builder()
160+
.model("mistral-embed") // mistral-embed is currently the only model available for embedding
161+
.input(exampleTexts)
162+
.build();
163+
164+
EmbeddingResponse embeddingsResponse = client.createEmbedding(embeddingRequest);
165+
// Embeddings are returned as a list of FloatEmbedding objects. FloatEmbedding objects contain a list of floats per input string.
166+
// See the Mistral documentation for more information: https://docs.mistral.ai/guides/embeddings/
167+
List<FloatEmbedding> embeddings = embeddingsResponse.getData();
168+
embeddings.forEach(embedding -> System.out.println(embedding.getEmbedding()));
169+
```
170+
171+
Example output:
172+
173+
```
174+
[-0.028015137, 0.02532959, 0.042785645, ... , -0.020980835, 0.011947632, -0.0035934448]
175+
[-0.02015686, 0.04272461, 0.05529785, ... , -0.006855011, 0.009529114, -0.016448975]
176+
```
177+
140178
# Roadmap
141179

142180
- [ ] Add support for streaming in message completions
143-
- [ ] Add support for embedding models
144181
- [ ] Figure out how Mistral handles rate limiting and create a queue system to handle it
145182
- [ ] Unit tests
146183

src/main/java/nl/dannyj/mistral/MistralClient.java

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@
1717
package nl.dannyj.mistral;
1818

1919
import com.fasterxml.jackson.databind.ObjectMapper;
20+
import jakarta.validation.ConstraintViolationException;
2021
import lombok.Getter;
2122
import lombok.NonNull;
2223
import lombok.Setter;
24+
import nl.dannyj.mistral.exceptions.UnexpectedResponseException;
2325
import nl.dannyj.mistral.interceptors.MistralHeaderInterceptor;
24-
import nl.dannyj.mistral.models.request.ChatCompletionRequest;
25-
import nl.dannyj.mistral.models.response.ChatCompletionResponse;
26-
import nl.dannyj.mistral.models.response.ListModelsResponse;
26+
import nl.dannyj.mistral.models.completion.ChatCompletionRequest;
27+
import nl.dannyj.mistral.models.completion.ChatCompletionResponse;
28+
import nl.dannyj.mistral.models.embedding.EmbeddingRequest;
29+
import nl.dannyj.mistral.models.embedding.EmbeddingResponse;
30+
import nl.dannyj.mistral.models.model.ListModelsResponse;
2731
import nl.dannyj.mistral.services.HttpService;
2832
import nl.dannyj.mistral.services.MistralService;
2933
import okhttp3.OkHttpClient;
@@ -50,6 +54,7 @@ public class MistralClient {
5054

5155
/**
5256
* Constructor that initializes the MistralClient with a provided API key.
57+
*
5358
* @param apiKey The API key to be used for the Mistral AI API
5459
*/
5560
public MistralClient(@NonNull String apiKey) {
@@ -71,8 +76,9 @@ public MistralClient() {
7176

7277
/**
7378
* Constructor that initializes the MistralClient with a provided API key, HTTP client, and object mapper.
74-
* @param apiKey The API key to be used for the Mistral AI API
75-
* @param httpClient The OkHttpClient to be used for making requests to the Mistral AI API
79+
*
80+
* @param apiKey The API key to be used for the Mistral AI API
81+
* @param httpClient The OkHttpClient to be used for making requests to the Mistral AI API
7682
* @param objectMapper The Jackson ObjectMapper to be used for serializing and deserializing JSON
7783
*/
7884
public MistralClient(@NonNull String apiKey, @NonNull OkHttpClient httpClient, @NonNull ObjectMapper objectMapper) {
@@ -84,7 +90,8 @@ public MistralClient(@NonNull String apiKey, @NonNull OkHttpClient httpClient, @
8490

8591
/**
8692
* Constructor that initializes the MistralClient with a provided API key and HTTP client.
87-
* @param apiKey The API key to be used for the Mistral AI API
93+
*
94+
* @param apiKey The API key to be used for the Mistral AI API
8895
* @param httpClient The OkHttpClient to be used for making requests to the Mistral AI API
8996
*/
9097
public MistralClient(@NonNull String apiKey, @NonNull OkHttpClient httpClient) {
@@ -96,7 +103,8 @@ public MistralClient(@NonNull String apiKey, @NonNull OkHttpClient httpClient) {
96103

97104
/**
98105
* Constructor that initializes the MistralClient with a provided API key and object mapper.
99-
* @param apiKey The API key to be used for the Mistral AI API
106+
*
107+
* @param apiKey The API key to be used for the Mistral AI API
100108
* @param objectMapper The Jackson ObjectMapper to be used for serializing and deserializing JSON
101109
*/
102110
public MistralClient(@NonNull String apiKey, @NonNull ObjectMapper objectMapper) {
@@ -109,8 +117,12 @@ public MistralClient(@NonNull String apiKey, @NonNull ObjectMapper objectMapper)
109117
/**
110118
* Use the Mistral AI API to create a chat completion (an assistant reply to the conversation).
111119
* This is a blocking method.
120+
*
112121
* @param request The request to create a chat completion. See {@link ChatCompletionRequest}.
113122
* @return The response from the Mistral AI API containing the generated message. See {@link ChatCompletionResponse}.
123+
* @throws ConstraintViolationException if the request does not pass validation
124+
* @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API
125+
* @throws IllegalArgumentException if the first message role is not 'user' or 'system'
114126
*/
115127
public ChatCompletionResponse createChatCompletion(@NonNull ChatCompletionRequest request) {
116128
return mistralService.createChatCompletion(request);
@@ -119,23 +131,70 @@ public ChatCompletionResponse createChatCompletion(@NonNull ChatCompletionReques
119131
/**
120132
* Use the Mistral AI API to create a chat completion (an assistant reply to the conversation).
121133
* This is a non-blocking/asynchronous method.
134+
*
122135
* @param request The request to create a chat completion. See {@link ChatCompletionRequest}.
123136
* @return A CompletableFuture that will complete with generated message from the Mistral AI API. See {@link ChatCompletionResponse}.
137+
* @throws ConstraintViolationException if the request does not pass validation
138+
* @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API
139+
* @throws IllegalArgumentException if the first message role is not 'user' or 'system'
124140
*/
125141
public CompletableFuture<ChatCompletionResponse> createChatCompletionAsync(@NonNull ChatCompletionRequest request) {
126142
return mistralService.createChatCompletionAsync(request);
127143
}
128144

145+
/**
146+
* This method is used to create an embedding using the Mistral AI API.
147+
* The embeddings for the input strings. See the <a href="https://docs.mistral.ai/guides/embeddings/">mistral documentation</a> for more details on embeddings.
148+
* This is a blocking method.
149+
*
150+
* @param request The request to create an embedding. See {@link EmbeddingRequest}.
151+
* @return The response from the Mistral AI API containing the generated embedding. See {@link EmbeddingResponse}.
152+
* @throws ConstraintViolationException if the request does not pass validation
153+
* @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API
154+
*/
155+
public EmbeddingResponse createEmbedding(@NonNull EmbeddingRequest request) {
156+
return mistralService.createEmbedding(request);
157+
}
158+
159+
/**
160+
* This method is used to create an embedding using the Mistral AI API.
161+
* The embeddings for the input strings. See the <a href="https://docs.mistral.ai/guides/embeddings/">mistral documentation</a> for more details on embeddings.
162+
* This is a non-blocking/asynchronous method.
163+
*
164+
* @param request The request to create an embedding. See {@link EmbeddingRequest}.
165+
* @return A CompletableFuture that will complete with the generated embedding from the Mistral AI API. See {@link EmbeddingResponse}.
166+
* @throws ConstraintViolationException if the request does not pass validation
167+
* @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API
168+
*/
169+
public CompletableFuture<EmbeddingResponse> createEmbeddingAsync(@NonNull EmbeddingRequest request) {
170+
return mistralService.createEmbeddingAsync(request);
171+
}
172+
129173
/**
130174
* Lists all models available according to the Mistral AI API.
175+
* This is a blocking method.
176+
*
131177
* @return The response from the Mistral AI API containing the list of models. See {@link ListModelsResponse}.
178+
* @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API
132179
*/
133180
public ListModelsResponse listModels() {
134181
return mistralService.listModels();
135182
}
136183

184+
/**
185+
* Lists all models available according to the Mistral AI API.
186+
* This is a non-blocking/asynchronous method.
187+
*
188+
* @return A CompletableFuture that will complete with the list of models from the Mistral AI API. See {@link ListModelsResponse}.
189+
* @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API
190+
*/
191+
public CompletableFuture<ListModelsResponse> listModelsAsync() {
192+
return mistralService.listModelsAsync();
193+
}
194+
137195
/**
138196
* Builds the MistralService.
197+
*
139198
* @return A new instance of MistralService
140199
*/
141200
private MistralService buildMistralService() {
@@ -144,6 +203,7 @@ private MistralService buildMistralService() {
144203

145204
/**
146205
* Builds the HTTP client.
206+
*
147207
* @return A new instance of OkHttpClient
148208
*/
149209
private OkHttpClient buildHttpClient() {
@@ -157,6 +217,7 @@ private OkHttpClient buildHttpClient() {
157217

158218
/**
159219
* Builds the object mapper.
220+
*
160221
* @return A new instance of ObjectMapper
161222
*/
162223
private ObjectMapper buildObjectMapper() {

src/main/java/nl/dannyj/mistral/builders/MessageListBuilder.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
package nl.dannyj.mistral.builders;
1818

19-
import nl.dannyj.mistral.models.Message;
20-
import nl.dannyj.mistral.models.MessageRole;
19+
import nl.dannyj.mistral.models.completion.Message;
20+
import nl.dannyj.mistral.models.completion.MessageRole;
2121

2222
import java.util.ArrayList;
2323
import java.util.List;
@@ -40,6 +40,7 @@ public MessageListBuilder() {
4040

4141
/**
4242
* Constructor that initializes the list of Message objects with a provided list.
43+
*
4344
* @param messages The initial list of Message objects
4445
*/
4546
public MessageListBuilder(List<Message> messages) {
@@ -48,6 +49,7 @@ public MessageListBuilder(List<Message> messages) {
4849

4950
/**
5051
* Adds a message with the system role to the list with the provided content.
52+
*
5153
* @param content The content of the system message
5254
* @return The builder instance
5355
*/
@@ -58,6 +60,7 @@ public MessageListBuilder system(String content) {
5860

5961
/**
6062
* Adds a message with the assistant role to the list with the provided content.
63+
*
6164
* @param content The content of the assistant message
6265
* @return The builder instance
6366
*/
@@ -68,6 +71,7 @@ public MessageListBuilder assistant(String content) {
6871

6972
/**
7073
* Adds a message with the user role to the list with the provided content.
74+
*
7175
* @param content The content of the user message
7276
* @return The builder instance
7377
*/
@@ -78,6 +82,7 @@ public MessageListBuilder user(String content) {
7882

7983
/**
8084
* Adds a custom Message object to the list.
85+
*
8186
* @param message The Message object to be added
8287
* @return The builder instance
8388
*/
@@ -88,6 +93,7 @@ public MessageListBuilder message(Message message) {
8893

8994
/**
9095
* Returns the list of Message objects that have been added.
96+
*
9197
* @return The list of Message objects
9298
*/
9399
public List<Message> build() {
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/*
2+
* Copyright 2024 Danny Jelsma
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package nl.dannyj.mistral.models;
18+
19+
public interface Request {
20+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/*
2+
* Copyright 2024 Danny Jelsma
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package nl.dannyj.mistral.models;
18+
19+
public interface Response {
20+
}

src/main/java/nl/dannyj/mistral/models/request/ChatCompletionRequest.java renamed to src/main/java/nl/dannyj/mistral/models/completion/ChatCompletionRequest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
* limitations under the License.
1515
*/
1616

17-
package nl.dannyj.mistral.models.request;
17+
package nl.dannyj.mistral.models.completion;
1818

1919
import com.fasterxml.jackson.annotation.JsonProperty;
2020
import jakarta.validation.constraints.*;
2121
import lombok.AllArgsConstructor;
2222
import lombok.Builder;
2323
import lombok.Data;
2424
import lombok.NoArgsConstructor;
25-
import nl.dannyj.mistral.models.Message;
25+
import nl.dannyj.mistral.models.Request;
2626

2727
import java.util.List;
2828

@@ -34,7 +34,7 @@
3434
@AllArgsConstructor
3535
@NoArgsConstructor
3636
@Builder
37-
public class ChatCompletionRequest {
37+
public class ChatCompletionRequest implements Request {
3838

3939
/**
4040
* ID of the model to use. You can use the List Available Models API to see all of your available models.

0 commit comments

Comments
 (0)