Skip to content
Merged
4 changes: 2 additions & 2 deletions src/main/java/nl/dannyj/mistral/MistralClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ public void createChatCompletionStream(@NonNull ChatCompletionRequest request, @
* @return A new instance of MistralService
*/
private MistralService buildMistralService() {
return new MistralService(this, new HttpService(this));
return new MistralService(new HttpService(this.httpClient), this.objectMapper);
}

/**
Expand All @@ -260,7 +260,7 @@ private MistralService buildMistralService() {
* @return A new instance of OkHttpClient
*/
private OkHttpClient buildHttpClient(int readTimeoutSeconds, int connectTimeoutSeconds, int writeTimeoutSeconds) {
MistralHeaderInterceptor mistralInterceptor = new MistralHeaderInterceptor(this);
MistralHeaderInterceptor mistralInterceptor = new MistralHeaderInterceptor(this.getApiKey());

return new OkHttpClient.Builder()
.readTimeout(readTimeoutSeconds, TimeUnit.SECONDS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package nl.dannyj.mistral.interceptors;

import lombok.NonNull;
import nl.dannyj.mistral.MistralClient;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
Expand All @@ -27,10 +25,14 @@

public class MistralHeaderInterceptor implements Interceptor {

private final MistralClient client;
private final String apiKey;

public MistralHeaderInterceptor(@NonNull MistralClient client) {
this.client = client;
public MistralHeaderInterceptor(@NotNull String apiKey) {
if (apiKey == null || apiKey.isBlank()) {
throw new IllegalArgumentException("No API key provided in MistralClient");
}

this.apiKey = apiKey;
}

@NotNull
Expand All @@ -39,10 +41,6 @@ public Response intercept(@NotNull Chain chain) throws IOException {
Request request = chain.request();
Request.Builder newRequestBuilder = request.newBuilder();

if (client.getApiKey() == null || client.getApiKey().isBlank()) {
throw new IllegalArgumentException("No API key provided in MistralClient");
}

if (request.header("Content-Type") == null) {
newRequestBuilder.addHeader("Content-Type", "application/json");
}
Expand All @@ -52,7 +50,7 @@ public Response intercept(@NotNull Chain chain) throws IOException {
}

if (request.header("Authorization") == null) {
newRequestBuilder.addHeader("Authorization", "Bearer " + client.getApiKey());
newRequestBuilder.addHeader("Authorization", "Bearer " + this.apiKey);
}

Request newRequest = newRequestBuilder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ public abstract class ChatMessage {
/**
* The content of the message. Can be null or a list of content chunks.
*
* @param content The list of content chunks, or null.
* @return The list of content chunks, or null.
*/
@Nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

public class ContentChunkListDeserializer extends StdDeserializer<List<ContentChunk>> implements ContextualDeserializer {

private JsonDeserializer<?> defaultDeserializer;
private transient JsonDeserializer<?> defaultDeserializer;

public ContentChunkListDeserializer() {
this(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,8 @@ public ToolChoiceOption deserialize(JsonParser jp, DeserializationContext ctxt)

if (token == JsonToken.VALUE_STRING) {
String enumValue = jp.getText().toUpperCase();
try {
return ToolChoiceEnum.valueOf(enumValue);
} catch (IllegalArgumentException e) {
if ("ANY".equalsIgnoreCase(enumValue)) return ToolChoiceEnum.ANY;
if ("AUTO".equalsIgnoreCase(enumValue)) return ToolChoiceEnum.AUTO;
if ("NONE".equalsIgnoreCase(enumValue)) return ToolChoiceEnum.NONE;
if ("REQUIRED".equalsIgnoreCase(enumValue)) return ToolChoiceEnum.REQUIRED;

throw ctxt.weirdStringException(enumValue, ToolChoiceEnum.class, "Not a valid ToolChoiceEnum value");
}

return ToolChoiceEnum.valueOf(enumValue);
} else if (token == JsonToken.START_OBJECT) {
return mapper.readValue(jp, SpecificToolChoice.class);
}
Expand Down
12 changes: 4 additions & 8 deletions src/main/java/nl/dannyj/mistral/services/HttpService.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package nl.dannyj.mistral.services;

import lombok.NonNull;
import nl.dannyj.mistral.MistralClient;
import nl.dannyj.mistral.exceptions.MistralAPIException;
import okhttp3.Callback;
import okhttp3.MediaType;
Expand All @@ -37,15 +36,15 @@ public class HttpService {

private static final String API_URL = "https://api.mistral.ai/v1";

private final MistralClient client;
private final OkHttpClient httpClient;

/**
* Constructor that initializes the HttpService with a provided MistralClient.
*
* @param client The MistralClient to be used for making requests to the Mistral AI API
* @param httpClient The OkHttpClient to be used for making requests to the Mistral AI API
*/
public HttpService(@NonNull MistralClient client) {
this.client = client;
public HttpService(@NonNull OkHttpClient httpClient) {
this.httpClient = httpClient;
}

/**
Expand Down Expand Up @@ -91,7 +90,6 @@ public void streamPost(@NonNull String urlPath, @NonNull String body, Callback c
.url(API_URL + urlPath)
.post(RequestBody.create(body, MediaType.parse("application/json")))
.build();
OkHttpClient httpClient = client.getHttpClient();

httpClient.newCall(request).enqueue(callBack);
}
Expand All @@ -104,8 +102,6 @@ public void streamPost(@NonNull String urlPath, @NonNull String body, Callback c
* @throws MistralAPIException If the response is not successful, the response body is null or an IOException occurs in the objectmapper
*/
private String executeRequest(Request request) {
OkHttpClient httpClient = client.getHttpClient();

try (Response response = httpClient.newCall(request).execute()) {
if (!response.isSuccessful()) {
throw new MistralAPIException("Received unexpected response code " + response.code() + ": " + (response.body() != null ? response.body().string() : response));
Expand Down
22 changes: 11 additions & 11 deletions src/main/java/nl/dannyj/mistral/services/MistralService.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
package nl.dannyj.mistral.services;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.validation.ConstraintViolation;
import jakarta.validation.ConstraintViolationException;
import jakarta.validation.Validation;
import jakarta.validation.Validator;
import jakarta.validation.ValidatorFactory;
import lombok.NonNull;
import nl.dannyj.mistral.MistralClient;
import nl.dannyj.mistral.exceptions.InvalidJsonException;
import nl.dannyj.mistral.exceptions.UnexpectedResponseEndException;
import nl.dannyj.mistral.exceptions.UnexpectedResponseException;
Expand Down Expand Up @@ -55,17 +55,17 @@
public class MistralService {

private final HttpService httpService;
private final MistralClient client;
private final ObjectMapper objectMapper;
private final Validator validator;

/**
* Constructor that initializes the MistralService with a provided MistralClient and HttpService.
*
* @param client The MistralClient to be used for interacting with the Mistral AI API
* @param httpService The HttpService to be used for making HTTP requests
* @param httpService The HttpService to be used for making HTTP requests to the Mistral AI API
* @param objectMapper The ObjectMapper to be used for converting objects to and from JSON
*/
public MistralService(@NonNull MistralClient client, @NonNull HttpService httpService) {
this.client = client;
public MistralService(@NonNull HttpService httpService, @NonNull ObjectMapper objectMapper) {
this.objectMapper = objectMapper;
this.httpService = httpService;

try (ValidatorFactory validatorFactory = Validation.buildDefaultValidatorFactory()) {
Expand Down Expand Up @@ -118,7 +118,7 @@ public void createChatCompletionStream(@NonNull ChatCompletionRequest request, @
validateRequest(request);

try {
String requestJson = client.getObjectMapper().writeValueAsString(request);
String requestJson = this.objectMapper.writeValueAsString(request);

httpService.streamPost("/chat/completions", requestJson, new Callback() {
@Override
Expand Down Expand Up @@ -161,7 +161,7 @@ public ListModelsResponse listModels() {
String response = httpService.get("/models");

try {
return client.getObjectMapper().readValue(response, ListModelsResponse.class);
return this.objectMapper.readValue(response, ListModelsResponse.class);
} catch (JsonProcessingException e) {
throw new UnexpectedResponseException("Received unexpected response from the Mistral.ai API (mistral-java-client might need to be updated): " + response, e);
}
Expand Down Expand Up @@ -238,15 +238,15 @@ private <T extends Request, U extends Response> U postRequest(String endpoint, T
String requestJson = null;

try {
requestJson = client.getObjectMapper().writeValueAsString(request);
requestJson = this.objectMapper.writeValueAsString(request);
} catch (JsonProcessingException e) {
throw new InvalidJsonException("Failed to convert request to JSON", e);
}

try {
response = httpService.post(endpoint, requestJson);

return client.getObjectMapper().readValue(response, responseType);
return this.objectMapper.readValue(response, responseType);
} catch (JsonProcessingException e) {
throw new UnexpectedResponseException("Received unexpected response from the Mistral.ai API (mistral-java-client might need to be updated): " + response, e);
}
Expand All @@ -266,7 +266,7 @@ private void handleResponseBody(@NotNull ResponseBody responseBody, ChatCompleti
}

try {
MessageChunk messageChunk = client.getObjectMapper().readValue(chunk, MessageChunk.class);
MessageChunk messageChunk = this.objectMapper.readValue(chunk, MessageChunk.class);

callback.onChunkReceived(messageChunk);
} catch (JsonProcessingException e) {
Expand Down