diff --git a/src/main/java/nl/dannyj/mistral/MistralClient.java b/src/main/java/nl/dannyj/mistral/MistralClient.java index 0b7e24c..1e130b0 100644 --- a/src/main/java/nl/dannyj/mistral/MistralClient.java +++ b/src/main/java/nl/dannyj/mistral/MistralClient.java @@ -251,7 +251,7 @@ public void createChatCompletionStream(@NonNull ChatCompletionRequest request, @ * @return A new instance of MistralService */ private MistralService buildMistralService() { - return new MistralService(this, new HttpService(this)); + return new MistralService(new HttpService(this.httpClient), this.objectMapper); } /** @@ -260,7 +260,7 @@ private MistralService buildMistralService() { * @return A new instance of OkHttpClient */ private OkHttpClient buildHttpClient(int readTimeoutSeconds, int connectTimeoutSeconds, int writeTimeoutSeconds) { - MistralHeaderInterceptor mistralInterceptor = new MistralHeaderInterceptor(this); + MistralHeaderInterceptor mistralInterceptor = new MistralHeaderInterceptor(this.getApiKey()); return new OkHttpClient.Builder() .readTimeout(readTimeoutSeconds, TimeUnit.SECONDS) diff --git a/src/main/java/nl/dannyj/mistral/builders/MessageListBuilder.java b/src/main/java/nl/dannyj/mistral/builders/MessageListBuilder.java index 966ab00..625b541 100644 --- a/src/main/java/nl/dannyj/mistral/builders/MessageListBuilder.java +++ b/src/main/java/nl/dannyj/mistral/builders/MessageListBuilder.java @@ -18,7 +18,7 @@ import jakarta.annotation.Nullable; import jakarta.validation.constraints.NotEmpty; -import jakarta.validation.constraints.NotNull; +import lombok.NonNull; import nl.dannyj.mistral.models.completion.message.AssistantMessage; import nl.dannyj.mistral.models.completion.message.ChatMessage; import nl.dannyj.mistral.models.completion.message.SystemMessage; @@ -59,7 +59,7 @@ public MessageListBuilder(List messages) { * @param content The text content of the system message. Cannot be null. * @return This builder instance. */ - public MessageListBuilder system(@NotNull String content) { + public MessageListBuilder system(@NonNull String content) { this.messages.add(new SystemMessage(content)); return this; } @@ -70,7 +70,7 @@ public MessageListBuilder system(@NotNull String content) { * @param content The text content of the assistant message. Cannot be null. * @return This builder instance. */ - public MessageListBuilder assistant(@NotNull String content) { + public MessageListBuilder assistant(@NonNull String content) { this.messages.add(new AssistantMessage(content)); return this; } @@ -81,7 +81,7 @@ public MessageListBuilder assistant(@NotNull String content) { * @param toolCalls The list of tool calls. Cannot be null or empty. * @return This builder instance. */ - public MessageListBuilder assistant(@NotNull @NotEmpty List toolCalls) { + public MessageListBuilder assistant(@NonNull @NotEmpty List toolCalls) { this.messages.add(new AssistantMessage(toolCalls)); return this; } @@ -92,7 +92,7 @@ public MessageListBuilder assistant(@NotNull @NotEmpty List toolCalls) * @param content The text content of the user message. Cannot be null. * @return This builder instance. */ - public MessageListBuilder user(@NotNull String content) { + public MessageListBuilder user(@NonNull String content) { this.messages.add(new UserMessage(content)); return this; } @@ -104,7 +104,7 @@ public MessageListBuilder user(@NotNull String content) { * @param toolCallId The ID of the tool call this message responds to. Can be null. * @return This builder instance. */ - public MessageListBuilder tool(@NotNull String content, @Nullable String toolCallId) { + public MessageListBuilder tool(@NonNull String content, @Nullable String toolCallId) { this.messages.add(new ToolMessage(content, toolCallId)); return this; } @@ -117,7 +117,7 @@ public MessageListBuilder tool(@NotNull String content, @Nullable String toolCal * @param message The ChatMessage object to be added. Cannot be null. * @return This builder instance. */ - public MessageListBuilder message(@NotNull ChatMessage message) { + public MessageListBuilder message(@NonNull ChatMessage message) { this.messages.add(message); return this; } diff --git a/src/main/java/nl/dannyj/mistral/interceptors/MistralHeaderInterceptor.java b/src/main/java/nl/dannyj/mistral/interceptors/MistralHeaderInterceptor.java index 5941bed..e433285 100644 --- a/src/main/java/nl/dannyj/mistral/interceptors/MistralHeaderInterceptor.java +++ b/src/main/java/nl/dannyj/mistral/interceptors/MistralHeaderInterceptor.java @@ -17,7 +17,6 @@ package nl.dannyj.mistral.interceptors; import lombok.NonNull; -import nl.dannyj.mistral.MistralClient; import okhttp3.Interceptor; import okhttp3.Request; import okhttp3.Response; @@ -27,22 +26,22 @@ public class MistralHeaderInterceptor implements Interceptor { - private final MistralClient client; + private final String apiKey; - public MistralHeaderInterceptor(@NonNull MistralClient client) { - this.client = client; + public MistralHeaderInterceptor(@NonNull String apiKey) { + if (apiKey.isBlank()) { + throw new IllegalArgumentException("No API key provided"); + } + + this.apiKey = apiKey; } @NotNull @Override - public Response intercept(@NotNull Chain chain) throws IOException { + public Response intercept(@NonNull Chain chain) throws IOException { Request request = chain.request(); Request.Builder newRequestBuilder = request.newBuilder(); - if (client.getApiKey() == null || client.getApiKey().isBlank()) { - throw new IllegalArgumentException("No API key provided in MistralClient"); - } - if (request.header("Content-Type") == null) { newRequestBuilder.addHeader("Content-Type", "application/json"); } @@ -52,7 +51,7 @@ public Response intercept(@NotNull Chain chain) throws IOException { } if (request.header("Authorization") == null) { - newRequestBuilder.addHeader("Authorization", "Bearer " + client.getApiKey()); + newRequestBuilder.addHeader("Authorization", "Bearer " + this.apiKey); } Request newRequest = newRequestBuilder.build(); diff --git a/src/main/java/nl/dannyj/mistral/models/completion/message/AssistantMessage.java b/src/main/java/nl/dannyj/mistral/models/completion/message/AssistantMessage.java index 462fa66..64263d7 100644 --- a/src/main/java/nl/dannyj/mistral/models/completion/message/AssistantMessage.java +++ b/src/main/java/nl/dannyj/mistral/models/completion/message/AssistantMessage.java @@ -20,10 +20,10 @@ import com.fasterxml.jackson.annotation.JsonProperty; import jakarta.annotation.Nullable; import jakarta.validation.constraints.NotEmpty; -import jakarta.validation.constraints.NotNull; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.Setter; import nl.dannyj.mistral.models.completion.content.ContentChunk; import nl.dannyj.mistral.models.completion.content.TextChunk; @@ -69,7 +69,7 @@ public class AssistantMessage extends ChatMessage { * * @param textContent The text content. */ - public AssistantMessage(@NotNull String textContent) { + public AssistantMessage(@NonNull String textContent) { this.content = Collections.singletonList(new TextChunk(textContent)); this.toolCalls = null; } @@ -79,7 +79,7 @@ public AssistantMessage(@NotNull String textContent) { * * @param toolCalls The list of tool calls. Cannot be null or empty. */ - public AssistantMessage(@NotNull @NotEmpty List toolCalls) { + public AssistantMessage(@NonNull @NotEmpty List toolCalls) { this.content = null; this.toolCalls = toolCalls; } diff --git a/src/main/java/nl/dannyj/mistral/models/completion/message/ChatMessage.java b/src/main/java/nl/dannyj/mistral/models/completion/message/ChatMessage.java index 3007374..a988cb3 100644 --- a/src/main/java/nl/dannyj/mistral/models/completion/message/ChatMessage.java +++ b/src/main/java/nl/dannyj/mistral/models/completion/message/ChatMessage.java @@ -52,7 +52,6 @@ public abstract class ChatMessage { /** * The content of the message. Can be null or a list of content chunks. * - * @param content The list of content chunks, or null. * @return The list of content chunks, or null. */ @Nullable diff --git a/src/main/java/nl/dannyj/mistral/models/completion/message/ToolMessage.java b/src/main/java/nl/dannyj/mistral/models/completion/message/ToolMessage.java index 74c5622..bb89f76 100644 --- a/src/main/java/nl/dannyj/mistral/models/completion/message/ToolMessage.java +++ b/src/main/java/nl/dannyj/mistral/models/completion/message/ToolMessage.java @@ -19,10 +19,10 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import jakarta.annotation.Nullable; -import jakarta.validation.constraints.NotNull; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.Setter; import nl.dannyj.mistral.models.completion.content.ContentChunk; import nl.dannyj.mistral.models.completion.content.TextChunk; @@ -68,7 +68,7 @@ public class ToolMessage extends ChatMessage { * @param textContent The text content (result) of the tool call. Cannot be null. * @param toolCallId The ID of the tool call this message responds to. Can be null. */ - public ToolMessage(@NotNull String textContent, @Nullable String toolCallId) { + public ToolMessage(@NonNull String textContent, @Nullable String toolCallId) { this.content = Collections.singletonList(new TextChunk(textContent)); this.toolCallId = toolCallId; } @@ -79,7 +79,7 @@ public ToolMessage(@NotNull String textContent, @Nullable String toolCallId) { * @param contentChunks The list of content chunks representing the tool result. Cannot be null or empty. * @param toolCallId The ID of the tool call this message responds to. Can be null. */ - public ToolMessage(@NotNull @jakarta.validation.constraints.NotEmpty List contentChunks, @Nullable String toolCallId) { + public ToolMessage(@NonNull @jakarta.validation.constraints.NotEmpty List contentChunks, @Nullable String toolCallId) { this.content = contentChunks; this.toolCallId = toolCallId; } diff --git a/src/main/java/nl/dannyj/mistral/models/completion/message/UserMessage.java b/src/main/java/nl/dannyj/mistral/models/completion/message/UserMessage.java index 5bd7c4e..7704a5c 100644 --- a/src/main/java/nl/dannyj/mistral/models/completion/message/UserMessage.java +++ b/src/main/java/nl/dannyj/mistral/models/completion/message/UserMessage.java @@ -17,9 +17,9 @@ package nl.dannyj.mistral.models.completion.message; import jakarta.validation.constraints.NotEmpty; -import jakarta.validation.constraints.NotNull; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.NonNull; import nl.dannyj.mistral.models.completion.content.ContentChunk; import nl.dannyj.mistral.models.completion.content.TextChunk; @@ -39,7 +39,7 @@ public class UserMessage extends ChatMessage { * * @param textContent The text content for the user message. Cannot be null or empty. */ - public UserMessage(@NotNull String textContent) { + public UserMessage(@NonNull String textContent) { if (textContent.isEmpty()) { throw new IllegalArgumentException("User message text content cannot be empty."); } @@ -51,7 +51,7 @@ public UserMessage(@NotNull String textContent) { * * @param contentChunks The list of content chunks. Cannot be null or empty. */ - public UserMessage(@NotNull @NotEmpty List contentChunks) { + public UserMessage(@NonNull @NotEmpty List contentChunks) { this.content = contentChunks; } diff --git a/src/main/java/nl/dannyj/mistral/serialization/ContentChunkListDeserializer.java b/src/main/java/nl/dannyj/mistral/serialization/ContentChunkListDeserializer.java index 75d5b78..1375832 100644 --- a/src/main/java/nl/dannyj/mistral/serialization/ContentChunkListDeserializer.java +++ b/src/main/java/nl/dannyj/mistral/serialization/ContentChunkListDeserializer.java @@ -34,7 +34,7 @@ public class ContentChunkListDeserializer extends StdDeserializer> implements ContextualDeserializer { - private JsonDeserializer defaultDeserializer; + private transient JsonDeserializer defaultDeserializer; public ContentChunkListDeserializer() { this(null); diff --git a/src/main/java/nl/dannyj/mistral/serialization/ToolChoiceOptionDeserializer.java b/src/main/java/nl/dannyj/mistral/serialization/ToolChoiceOptionDeserializer.java index bfedb83..b0ec2b2 100644 --- a/src/main/java/nl/dannyj/mistral/serialization/ToolChoiceOptionDeserializer.java +++ b/src/main/java/nl/dannyj/mistral/serialization/ToolChoiceOptionDeserializer.java @@ -41,16 +41,8 @@ public ToolChoiceOption deserialize(JsonParser jp, DeserializationContext ctxt) if (token == JsonToken.VALUE_STRING) { String enumValue = jp.getText().toUpperCase(); - try { - return ToolChoiceEnum.valueOf(enumValue); - } catch (IllegalArgumentException e) { - if ("ANY".equalsIgnoreCase(enumValue)) return ToolChoiceEnum.ANY; - if ("AUTO".equalsIgnoreCase(enumValue)) return ToolChoiceEnum.AUTO; - if ("NONE".equalsIgnoreCase(enumValue)) return ToolChoiceEnum.NONE; - if ("REQUIRED".equalsIgnoreCase(enumValue)) return ToolChoiceEnum.REQUIRED; - - throw ctxt.weirdStringException(enumValue, ToolChoiceEnum.class, "Not a valid ToolChoiceEnum value"); - } + + return ToolChoiceEnum.valueOf(enumValue); } else if (token == JsonToken.START_OBJECT) { return mapper.readValue(jp, SpecificToolChoice.class); } diff --git a/src/main/java/nl/dannyj/mistral/services/HttpService.java b/src/main/java/nl/dannyj/mistral/services/HttpService.java index bc1735f..4cf057a 100644 --- a/src/main/java/nl/dannyj/mistral/services/HttpService.java +++ b/src/main/java/nl/dannyj/mistral/services/HttpService.java @@ -17,7 +17,6 @@ package nl.dannyj.mistral.services; import lombok.NonNull; -import nl.dannyj.mistral.MistralClient; import nl.dannyj.mistral.exceptions.MistralAPIException; import okhttp3.Callback; import okhttp3.MediaType; @@ -37,15 +36,15 @@ public class HttpService { private static final String API_URL = "https://api.mistral.ai/v1"; - private final MistralClient client; + private final OkHttpClient httpClient; /** - * Constructor that initializes the HttpService with a provided MistralClient. + * Constructor that initializes the HttpService with a provided OkHttpClient. * - * @param client The MistralClient to be used for making requests to the Mistral AI API + * @param httpClient The OkHttpClient to be used for making requests to the Mistral AI API */ - public HttpService(@NonNull MistralClient client) { - this.client = client; + public HttpService(@NonNull OkHttpClient httpClient) { + this.httpClient = httpClient; } /** @@ -91,7 +90,6 @@ public void streamPost(@NonNull String urlPath, @NonNull String body, Callback c .url(API_URL + urlPath) .post(RequestBody.create(body, MediaType.parse("application/json"))) .build(); - OkHttpClient httpClient = client.getHttpClient(); httpClient.newCall(request).enqueue(callBack); } @@ -104,8 +102,6 @@ public void streamPost(@NonNull String urlPath, @NonNull String body, Callback c * @throws MistralAPIException If the response is not successful, the response body is null or an IOException occurs in the objectmapper */ private String executeRequest(Request request) { - OkHttpClient httpClient = client.getHttpClient(); - try (Response response = httpClient.newCall(request).execute()) { if (!response.isSuccessful()) { throw new MistralAPIException("Received unexpected response code " + response.code() + ": " + (response.body() != null ? response.body().string() : response)); diff --git a/src/main/java/nl/dannyj/mistral/services/MistralService.java b/src/main/java/nl/dannyj/mistral/services/MistralService.java index b295189..a28da8e 100644 --- a/src/main/java/nl/dannyj/mistral/services/MistralService.java +++ b/src/main/java/nl/dannyj/mistral/services/MistralService.java @@ -17,13 +17,13 @@ package nl.dannyj.mistral.services; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import jakarta.validation.ConstraintViolation; import jakarta.validation.ConstraintViolationException; import jakarta.validation.Validation; import jakarta.validation.Validator; import jakarta.validation.ValidatorFactory; import lombok.NonNull; -import nl.dannyj.mistral.MistralClient; import nl.dannyj.mistral.exceptions.InvalidJsonException; import nl.dannyj.mistral.exceptions.UnexpectedResponseEndException; import nl.dannyj.mistral.exceptions.UnexpectedResponseException; @@ -41,7 +41,6 @@ import okhttp3.Call; import okhttp3.Callback; import okhttp3.ResponseBody; -import org.jetbrains.annotations.NotNull; import java.io.BufferedReader; import java.io.IOException; @@ -55,17 +54,17 @@ public class MistralService { private final HttpService httpService; - private final MistralClient client; + private final ObjectMapper objectMapper; private final Validator validator; /** - * Constructor that initializes the MistralService with a provided MistralClient and HttpService. + * Constructor that initializes the MistralService with a provided HttpService and ObjectMapper. * - * @param client The MistralClient to be used for interacting with the Mistral AI API - * @param httpService The HttpService to be used for making HTTP requests + * @param httpService The HttpService to be used for making HTTP requests to the Mistral AI API + * @param objectMapper The ObjectMapper to be used for converting objects to and from JSON */ - public MistralService(@NonNull MistralClient client, @NonNull HttpService httpService) { - this.client = client; + public MistralService(@NonNull HttpService httpService, @NonNull ObjectMapper objectMapper) { + this.objectMapper = objectMapper; this.httpService = httpService; try (ValidatorFactory validatorFactory = Validation.buildDefaultValidatorFactory()) { @@ -118,11 +117,11 @@ public void createChatCompletionStream(@NonNull ChatCompletionRequest request, @ validateRequest(request); try { - String requestJson = client.getObjectMapper().writeValueAsString(request); + String requestJson = this.objectMapper.writeValueAsString(request); httpService.streamPost("/chat/completions", requestJson, new Callback() { @Override - public void onResponse(@NotNull Call call, @NotNull okhttp3.Response response) { + public void onResponse(@NonNull Call call, @NonNull okhttp3.Response response) { if (!response.isSuccessful()) { callback.onError(new UnexpectedResponseException("Received unexpected response code " + response.code() + ": " + response)); return; @@ -141,7 +140,7 @@ public void onResponse(@NotNull Call call, @NotNull okhttp3.Response response) { } @Override - public void onFailure(@NotNull Call call, @NotNull IOException e) { + public void onFailure(@NonNull Call call, @NonNull IOException e) { callback.onError(e); } }); @@ -161,7 +160,7 @@ public ListModelsResponse listModels() { String response = httpService.get("/models"); try { - return client.getObjectMapper().readValue(response, ListModelsResponse.class); + return this.objectMapper.readValue(response, ListModelsResponse.class); } catch (JsonProcessingException e) { throw new UnexpectedResponseException("Received unexpected response from the Mistral.ai API (mistral-java-client might need to be updated): " + response, e); } @@ -238,7 +237,7 @@ private U postRequest(String endpoint, T String requestJson = null; try { - requestJson = client.getObjectMapper().writeValueAsString(request); + requestJson = this.objectMapper.writeValueAsString(request); } catch (JsonProcessingException e) { throw new InvalidJsonException("Failed to convert request to JSON", e); } @@ -246,13 +245,13 @@ private U postRequest(String endpoint, T try { response = httpService.post(endpoint, requestJson); - return client.getObjectMapper().readValue(response, responseType); + return this.objectMapper.readValue(response, responseType); } catch (JsonProcessingException e) { throw new UnexpectedResponseException("Received unexpected response from the Mistral.ai API (mistral-java-client might need to be updated): " + response, e); } } - private void handleResponseBody(@NotNull ResponseBody responseBody, ChatCompletionChunkCallback callback) throws IOException { + private void handleResponseBody(@NonNull ResponseBody responseBody, ChatCompletionChunkCallback callback) throws IOException { BufferedReader reader = new BufferedReader(responseBody.charStream()); String line; @@ -266,7 +265,7 @@ private void handleResponseBody(@NotNull ResponseBody responseBody, ChatCompleti } try { - MessageChunk messageChunk = client.getObjectMapper().readValue(chunk, MessageChunk.class); + MessageChunk messageChunk = this.objectMapper.readValue(chunk, MessageChunk.class); callback.onChunkReceived(messageChunk); } catch (JsonProcessingException e) {