|
1 | 1 | package io.codemodder.plugins.llm; |
2 | 2 |
|
3 | | -import com.theokanning.openai.client.OpenAiApi; |
4 | | -import com.theokanning.openai.completion.chat.ChatCompletionRequest; |
5 | | -import com.theokanning.openai.completion.chat.ChatCompletionResult; |
6 | | -import com.theokanning.openai.service.OpenAiService; |
7 | | -import io.reactivex.Flowable; |
8 | | -import io.reactivex.functions.Function; |
9 | | -import java.net.SocketTimeoutException; |
| 3 | +import com.azure.ai.openai.OpenAIClient; |
| 4 | +import com.azure.ai.openai.OpenAIClientBuilder; |
| 5 | +import com.azure.ai.openai.models.ChatCompletions; |
| 6 | +import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; |
| 7 | +import com.azure.ai.openai.models.ChatCompletionsOptions; |
| 8 | +import com.azure.ai.openai.models.ChatRequestMessage; |
| 9 | +import com.azure.core.credential.AzureKeyCredential; |
| 10 | +import com.azure.core.credential.KeyCredential; |
| 11 | +import com.azure.core.http.policy.RetryPolicy; |
| 12 | +import com.azure.core.util.HttpClientOptions; |
| 13 | +import com.fasterxml.jackson.databind.ObjectMapper; |
| 14 | +import java.io.IOException; |
10 | 15 | import java.time.Duration; |
11 | | -import java.util.concurrent.TimeUnit; |
12 | | -import org.slf4j.Logger; |
13 | | -import org.slf4j.LoggerFactory; |
14 | | -import retrofit2.HttpException; |
| 16 | +import java.util.List; |
| 17 | +import java.util.Objects; |
15 | 18 |
|
16 | | -/** |
17 | | - * A custom service class to call the {@link OpenAiApi}, since the out-of-the box {@link |
18 | | - * OpenAiService} <a href="https://github.com/TheoKanning/openai-java/issues/189">does not support |
19 | | - * automatic retries</a>. |
20 | | - */ |
| 19 | +/** A custom service class to wrap the {@link OpenAIClient} */ |
21 | 20 | public class OpenAIService { |
22 | | - private final OpenAiApi api; |
| 21 | + private final OpenAIClient api; |
| 22 | + private static final int TIMEOUT_SECONDS = 90; |
| 23 | + private final ModelMapper modelMapper; |
23 | 24 |
|
24 | | - public OpenAIService(final String token) { |
25 | | - this.api = OpenAiService.buildApi(token, Duration.ofSeconds(90)); |
| 25 | + private static OpenAIClientBuilder builder(final KeyCredential key) { |
| 26 | + HttpClientOptions clientOptions = new HttpClientOptions(); |
| 27 | + clientOptions.setReadTimeout(Duration.ofSeconds(TIMEOUT_SECONDS)); |
| 28 | + return new OpenAIClientBuilder() |
| 29 | + .retryPolicy(new RetryPolicy()) |
| 30 | + .clientOptions(clientOptions) |
| 31 | + .credential(key); |
26 | 32 | } |
27 | 33 |
|
28 | | - public ChatCompletionResult createChatCompletion(final ChatCompletionRequest request) { |
29 | | - return this.api |
30 | | - .createChatCompletion(request) |
31 | | - .retryWhen(new OpenAIRetryStrategy()) |
32 | | - .blockingGet(); |
| 34 | + OpenAIService(final ModelMapper mapper, final KeyCredential key) { |
| 35 | + this.modelMapper = mapper; |
| 36 | + this.api = builder(key).buildClient(); |
| 37 | + } |
| 38 | + |
| 39 | + OpenAIService(final ModelMapper mapper, final KeyCredential key, final String endpoint) { |
| 40 | + this.modelMapper = mapper; |
| 41 | + this.api = builder(key).endpoint(endpoint).buildClient(); |
33 | 42 | } |
34 | | -} |
35 | 43 |
|
36 | | -/** |
37 | | - * When there is a retryable error from OpenAI -- either a timeout or a retryable <a |
38 | | - * href="https://platform.openai.com/docs/guides/error-codes/api-errors">error code</a> -- this will |
39 | | - * retry the request up to 3 times, with a delay of {@code 1 * retryCount} seconds between retries. |
40 | | - */ |
41 | | -class OpenAIRetryStrategy implements Function<Flowable<? extends Throwable>, Flowable<Object>> { |
42 | | - private static final int MAX_RETRY_COUNT = 3; |
43 | | - private static final Logger logger = LoggerFactory.getLogger(OpenAIRetryStrategy.class); |
| 44 | + /** |
| 45 | + * Creates a new {@link OpenAIService} instance with the given OpenAI token. |
| 46 | + * |
| 47 | + * @param token the token to use |
| 48 | + * @return the new instance |
| 49 | + */ |
| 50 | + public static OpenAIService fromOpenAI(final String token) { |
| 51 | + return new OpenAIService( |
| 52 | + new EnvironmentBasedModelMapper(), new KeyCredential(Objects.requireNonNull(token))); |
| 53 | + } |
44 | 54 |
|
45 | | - private int retryCount = 0; |
| 55 | + /** |
| 56 | + * Creates a new {@link OpenAIService} instance with the given Azure OpenAI token and endpoint. |
| 57 | + * |
| 58 | + * @param token the token to use |
| 59 | + * @param endpoint the endpoint to use |
| 60 | + * @return the new instance |
| 61 | + */ |
| 62 | + public static OpenAIService fromAzureOpenAI(final String token, final String endpoint) { |
| 63 | + return new OpenAIService( |
| 64 | + new EnvironmentBasedModelMapper(), |
| 65 | + new AzureKeyCredential(Objects.requireNonNull(token)), |
| 66 | + Objects.requireNonNull(endpoint)); |
| 67 | + } |
46 | 68 |
|
47 | | - @Override |
48 | | - public Flowable<Object> apply(final Flowable<? extends Throwable> flowable) { |
49 | | - return flowable.flatMap( |
50 | | - e -> { |
51 | | - if (++retryCount <= MAX_RETRY_COUNT && isRetryable(e)) { |
52 | | - logger.warn("retrying after {}s: {}", retryCount, e); |
53 | | - return Flowable.timer(retryCount, TimeUnit.SECONDS); |
54 | | - } else { |
55 | | - return Flowable.error(e); |
56 | | - } |
57 | | - }); |
| 69 | + /** |
| 70 | + * Gets the completion for the given messages. |
| 71 | + * |
| 72 | + * @param messages the messages |
| 73 | + * @param modelOrDeploymentName the model or deployment name |
| 74 | + * @return the completion |
| 75 | + */ |
| 76 | + public String getJSONCompletion( |
| 77 | + final List<ChatRequestMessage> messages, final Model modelOrDeploymentName) { |
| 78 | + ChatCompletionsOptions options = |
| 79 | + new ChatCompletionsOptions(messages) |
| 80 | + .setTemperature(0D) |
| 81 | + .setN(1) |
| 82 | + .setResponseFormat(new ChatCompletionsJsonResponseFormat()); |
| 83 | + final var modelName = modelMapper.getModelName(modelOrDeploymentName); |
| 84 | + ChatCompletions completions = this.api.getChatCompletions(modelName, options); |
| 85 | + return completions.getChoices().get(0).getMessage().getContent().trim(); |
58 | 86 | } |
59 | 87 |
|
60 | | - private boolean isRetryable(final Throwable e) { |
61 | | - if (e instanceof SocketTimeoutException) { |
62 | | - return true; |
63 | | - } else if (e instanceof HttpException) { |
64 | | - int code = ((HttpException) e).code(); |
65 | | - return code == 429 || code == 500 || code == 503; |
| 88 | + /** |
| 89 | + * Returns an object of the given type based on the completion for the given messages. |
| 90 | + * |
| 91 | + * @param messages the messages |
| 92 | + * @param modelName the model name |
| 93 | + * @return the completion |
| 94 | + */ |
| 95 | + public <T> T getResponseForPrompt( |
| 96 | + final List<ChatRequestMessage> messages, final Model modelName, final Class<T> responseType) |
| 97 | + throws IOException { |
| 98 | + String json = getJSONCompletion(messages, modelName); |
| 99 | + |
| 100 | + // we've seen this with turbo/lesser models |
| 101 | + if (json.startsWith("```json") && json.endsWith("```")) { |
| 102 | + json = json.substring(7, json.length() - 3); |
| 103 | + } |
| 104 | + |
| 105 | + // try to deserialize the content as is |
| 106 | + ObjectMapper mapper = new ObjectMapper(); |
| 107 | + try { |
| 108 | + return mapper.readValue(json, responseType); |
| 109 | + } catch (IOException e) { |
| 110 | + int firstBorder = json.indexOf("```json"); |
| 111 | + int lastBorder = json.lastIndexOf("```"); |
| 112 | + if (firstBorder != -1 && lastBorder != -1 && lastBorder > firstBorder) { |
| 113 | + json = json.substring(firstBorder + 7, lastBorder); |
| 114 | + } |
| 115 | + return mapper.readValue(json, responseType); |
66 | 116 | } |
67 | | - return false; |
68 | 117 | } |
69 | 118 | } |
0 commit comments