Skip to content

Commit 98991c0

Browse files
authored
Enable compatibility with Azure OpenAI (#415)
1 parent 62d14db commit 98991c0

File tree

14 files changed

+300
-270
lines changed

14 files changed

+300
-270
lines changed

core-codemods/src/main/java/io/codemodder/codemods/LogFailedLoginCodemod.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ public LogFailedLoginCodemod(
6767
When no logger is in scope, the new code emits a log message to the console.
6868
"""
6969
.replace('\n', ' '))),
70-
StandardModel.GPT_4O,
71-
StandardModel.GPT_4);
70+
StandardModel.GPT_4O_2024_05_13,
71+
StandardModel.GPT_4_0613);
7272
}
7373

7474
@Override

core-codemods/src/main/java/io/codemodder/codemods/SensitiveDataLoggingCodemod.java

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.codemodder.codemods;
22

3+
import com.azure.ai.openai.models.ChatRequestUserMessage;
34
import com.contrastsecurity.sarif.PhysicalLocation;
45
import com.contrastsecurity.sarif.Result;
56
import com.contrastsecurity.sarif.Run;
@@ -8,10 +9,10 @@
89
import com.fasterxml.jackson.databind.ObjectReader;
910
import com.github.javaparser.ast.CompilationUnit;
1011
import com.github.javaparser.ast.stmt.Statement;
11-
import com.theokanning.openai.completion.chat.*;
1212
import io.codemodder.*;
1313
import io.codemodder.javaparser.JavaParserChanger;
1414
import io.codemodder.plugins.llm.OpenAIService;
15+
import io.codemodder.plugins.llm.StandardModel;
1516
import io.codemodder.providers.sarif.semgrep.SemgrepScan;
1617
import java.io.IOException;
1718
import java.io.UncheckedIOException;
@@ -98,21 +99,10 @@ private SensitivityAndFixAnalysis performSensitivityAnalysis(
9899
"""
99100
.formatted(startLine, codeSnippet);
100101

101-
ChatCompletionRequest request =
102-
ChatCompletionRequest.builder()
103-
.temperature(0D)
104-
.model("gpt-4o-2024-05-13")
105-
.n(1)
106-
.messages(List.of(new ChatMessage(ChatMessageRole.USER.value(), prompt)))
107-
.build();
108-
ChatCompletionResult completion = service.createChatCompletion(request);
109-
ChatCompletionChoice response = completion.getChoices().get(0);
110-
String responseText = response.getMessage().getContent();
111-
if (responseText.startsWith("```json") && responseText.endsWith("```")) {
112-
responseText =
113-
responseText.substring("```json".length(), responseText.length() - "```".length());
114-
}
115-
return reader.readValue(responseText);
102+
return service.getResponseForPrompt(
103+
List.of(new ChatRequestUserMessage(prompt)),
104+
StandardModel.GPT_4O_2024_05_13,
105+
SensitivityAndFixAnalysisDTO.class);
116106
}
117107

118108
/**

framework/codemodder-base/build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ dependencies {
2929
api(libs.javaparser.symbolsolver.model)
3030
api(libs.javadiff)
3131
api(libs.jtokkit)
32-
api(libs.openai.service)
32+
api("com.azure:azure-ai-openai:1.0.0-beta.10")
3333
api("io.github.classgraph:classgraph:4.8.160")
3434

3535
implementation(libs.tuples)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package io.codemodder.plugins.llm;
2+
3+
import java.util.HashMap;
4+
5+
/** Mapper that maps models to their deployment names based on environment variables. */
6+
final class EnvironmentBasedModelMapper implements ModelMapper {
7+
private static final String DEPLOYMENT_TEMPLATE = "CODEMODDER_AZURE_OPENAI_%s_DEPLOYMENT";
8+
9+
private final HashMap<Model, String> map = new HashMap<>();
10+
11+
EnvironmentBasedModelMapper() {
12+
for (Model m : StandardModel.values()) {
13+
final var deployment = System.getenv(String.format(DEPLOYMENT_TEMPLATE, m));
14+
map.put(m, deployment == null ? m.id() : deployment);
15+
}
16+
}
17+
18+
@Override
19+
public String getModelName(Model model) {
20+
return map.get(model);
21+
}
22+
}

plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/LLMServiceModule.java

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,36 @@
55
/** Provides configured LLM services. */
66
public final class LLMServiceModule extends AbstractModule {
77

8-
private static final String TOKEN_NAME = "CODEMODDER_OPENAI_API_KEY";
8+
private static final String OPENAI_KEY_NAME = "CODEMODDER_OPENAI_API_KEY";
9+
private static final String AZURE_OPENAI_KEY_NAME = "CODEMODDER_AZURE_OPENAI_API_KEY";
10+
private static final String AZURE_OPENAI_ENDPOINT = "CODEMODDER_AZURE_OPENAI_ENDPOINT";
911

1012
@Override
1113
protected void configure() {
12-
bind(OpenAIService.class).toProvider(() -> new OpenAIService(getToken()));
14+
final var azureOpenAIKey = System.getenv(AZURE_OPENAI_KEY_NAME);
15+
final var azureOpenAIEndpoint = System.getenv(AZURE_OPENAI_ENDPOINT);
16+
if ((azureOpenAIEndpoint == null) != (azureOpenAIKey == null)) {
17+
throw new IllegalArgumentException(
18+
"Both or neither of "
19+
+ AZURE_OPENAI_KEY_NAME
20+
+ " and "
21+
+ AZURE_OPENAI_ENDPOINT
22+
+ " must be set");
23+
}
24+
if (azureOpenAIKey != null) {
25+
bind(OpenAIService.class)
26+
.toProvider(() -> OpenAIService.fromAzureOpenAI(azureOpenAIKey, azureOpenAIEndpoint));
27+
return;
28+
}
29+
30+
bind(OpenAIService.class).toProvider(() -> OpenAIService.fromOpenAI(getOpenAIToken()));
1331
}
1432

15-
private String getToken() {
16-
String token = System.getenv(TOKEN_NAME);
17-
if (token == null) {
18-
throw new IllegalArgumentException(TOKEN_NAME + " environment variable must be set");
33+
private String getOpenAIToken() {
34+
final var openAIKey = System.getenv(OPENAI_KEY_NAME);
35+
if (openAIKey == null) {
36+
throw new IllegalArgumentException(OPENAI_KEY_NAME + " environment variable must be set");
1937
}
20-
return token;
38+
return openAIKey;
2139
}
2240
}

plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/Model.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package io.codemodder.plugins.llm;
22

3-
import com.theokanning.openai.completion.chat.ChatMessage;
43
import java.util.List;
54

65
/**
@@ -26,5 +25,5 @@ public interface Model {
2625
* @param messages the list of messages for which to estimate token usage
2726
* @return estimated tokens that would be consumed by the model
2827
*/
29-
int tokens(List<ChatMessage> messages);
28+
int tokens(List<String> messages);
3029
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package io.codemodder.plugins.llm;
2+
3+
/** Maps models to their deployment names. */
4+
interface ModelMapper {
5+
/**
6+
* Maps the given model to its deployment name.
7+
*
8+
* @param model the model to map
9+
* @return the deployment name of the model
10+
*/
11+
String getModelName(Model model);
12+
}
Lines changed: 101 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,118 @@
11
package io.codemodder.plugins.llm;
22

3-
import com.theokanning.openai.client.OpenAiApi;
4-
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
5-
import com.theokanning.openai.completion.chat.ChatCompletionResult;
6-
import com.theokanning.openai.service.OpenAiService;
7-
import io.reactivex.Flowable;
8-
import io.reactivex.functions.Function;
9-
import java.net.SocketTimeoutException;
3+
import com.azure.ai.openai.OpenAIClient;
4+
import com.azure.ai.openai.OpenAIClientBuilder;
5+
import com.azure.ai.openai.models.ChatCompletions;
6+
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
7+
import com.azure.ai.openai.models.ChatCompletionsOptions;
8+
import com.azure.ai.openai.models.ChatRequestMessage;
9+
import com.azure.core.credential.AzureKeyCredential;
10+
import com.azure.core.credential.KeyCredential;
11+
import com.azure.core.http.policy.RetryPolicy;
12+
import com.azure.core.util.HttpClientOptions;
13+
import com.fasterxml.jackson.databind.ObjectMapper;
14+
import java.io.IOException;
1015
import java.time.Duration;
11-
import java.util.concurrent.TimeUnit;
12-
import org.slf4j.Logger;
13-
import org.slf4j.LoggerFactory;
14-
import retrofit2.HttpException;
16+
import java.util.List;
17+
import java.util.Objects;
1518

16-
/**
17-
* A custom service class to call the {@link OpenAiApi}, since the out-of-the box {@link
18-
* OpenAiService} <a href="https://github.com/TheoKanning/openai-java/issues/189">does not support
19-
* automatic retries</a>.
20-
*/
19+
/** A custom service class to wrap the {@link OpenAIClient} */
2120
public class OpenAIService {
22-
private final OpenAiApi api;
21+
private final OpenAIClient api;
22+
private static final int TIMEOUT_SECONDS = 90;
23+
private final ModelMapper modelMapper;
2324

24-
public OpenAIService(final String token) {
25-
this.api = OpenAiService.buildApi(token, Duration.ofSeconds(90));
25+
private static OpenAIClientBuilder builder(final KeyCredential key) {
26+
HttpClientOptions clientOptions = new HttpClientOptions();
27+
clientOptions.setReadTimeout(Duration.ofSeconds(TIMEOUT_SECONDS));
28+
return new OpenAIClientBuilder()
29+
.retryPolicy(new RetryPolicy())
30+
.clientOptions(clientOptions)
31+
.credential(key);
2632
}
2733

28-
public ChatCompletionResult createChatCompletion(final ChatCompletionRequest request) {
29-
return this.api
30-
.createChatCompletion(request)
31-
.retryWhen(new OpenAIRetryStrategy())
32-
.blockingGet();
34+
OpenAIService(final ModelMapper mapper, final KeyCredential key) {
35+
this.modelMapper = mapper;
36+
this.api = builder(key).buildClient();
37+
}
38+
39+
OpenAIService(final ModelMapper mapper, final KeyCredential key, final String endpoint) {
40+
this.modelMapper = mapper;
41+
this.api = builder(key).endpoint(endpoint).buildClient();
3342
}
34-
}
3543

36-
/**
37-
* When there is a retryable error from OpenAI -- either a timeout or a retryable <a
38-
* href="https://platform.openai.com/docs/guides/error-codes/api-errors">error code</a> -- this will
39-
* retry the request up to 3 times, with a delay of {@code 1 * retryCount} seconds between retries.
40-
*/
41-
class OpenAIRetryStrategy implements Function<Flowable<? extends Throwable>, Flowable<Object>> {
42-
private static final int MAX_RETRY_COUNT = 3;
43-
private static final Logger logger = LoggerFactory.getLogger(OpenAIRetryStrategy.class);
44+
/**
45+
* Creates a new {@link OpenAIService} instance with the given OpenAI token.
46+
*
47+
* @param token the token to use
48+
* @return the new instance
49+
*/
50+
public static OpenAIService fromOpenAI(final String token) {
51+
return new OpenAIService(
52+
new EnvironmentBasedModelMapper(), new KeyCredential(Objects.requireNonNull(token)));
53+
}
4454

45-
private int retryCount = 0;
55+
/**
56+
* Creates a new {@link OpenAIService} instance with the given Azure OpenAI token and endpoint.
57+
*
58+
* @param token the token to use
59+
* @param endpoint the endpoint to use
60+
* @return the new instance
61+
*/
62+
public static OpenAIService fromAzureOpenAI(final String token, final String endpoint) {
63+
return new OpenAIService(
64+
new EnvironmentBasedModelMapper(),
65+
new AzureKeyCredential(Objects.requireNonNull(token)),
66+
Objects.requireNonNull(endpoint));
67+
}
4668

47-
@Override
48-
public Flowable<Object> apply(final Flowable<? extends Throwable> flowable) {
49-
return flowable.flatMap(
50-
e -> {
51-
if (++retryCount <= MAX_RETRY_COUNT && isRetryable(e)) {
52-
logger.warn("retrying after {}s: {}", retryCount, e);
53-
return Flowable.timer(retryCount, TimeUnit.SECONDS);
54-
} else {
55-
return Flowable.error(e);
56-
}
57-
});
69+
/**
70+
* Gets the completion for the given messages.
71+
*
72+
* @param messages the messages
73+
* @param modelOrDeploymentName the model or deployment name
74+
* @return the completion
75+
*/
76+
public String getJSONCompletion(
77+
final List<ChatRequestMessage> messages, final Model modelOrDeploymentName) {
78+
ChatCompletionsOptions options =
79+
new ChatCompletionsOptions(messages)
80+
.setTemperature(0D)
81+
.setN(1)
82+
.setResponseFormat(new ChatCompletionsJsonResponseFormat());
83+
final var modelName = modelMapper.getModelName(modelOrDeploymentName);
84+
ChatCompletions completions = this.api.getChatCompletions(modelName, options);
85+
return completions.getChoices().get(0).getMessage().getContent().trim();
5886
}
5987

60-
private boolean isRetryable(final Throwable e) {
61-
if (e instanceof SocketTimeoutException) {
62-
return true;
63-
} else if (e instanceof HttpException) {
64-
int code = ((HttpException) e).code();
65-
return code == 429 || code == 500 || code == 503;
88+
/**
89+
* Returns an object of the given type based on the completion for the given messages.
90+
*
91+
* @param messages the messages
92+
* @param modelName the model name
93+
* @return the completion
94+
*/
95+
public <T> T getResponseForPrompt(
96+
final List<ChatRequestMessage> messages, final Model modelName, final Class<T> responseType)
97+
throws IOException {
98+
String json = getJSONCompletion(messages, modelName);
99+
100+
// we've seen this with turbo/lesser models
101+
if (json.startsWith("```json") && json.endsWith("```")) {
102+
json = json.substring(7, json.length() - 3);
103+
}
104+
105+
// try to deserialize the content as is
106+
ObjectMapper mapper = new ObjectMapper();
107+
try {
108+
return mapper.readValue(json, responseType);
109+
} catch (IOException e) {
110+
int firstBorder = json.indexOf("```json");
111+
int lastBorder = json.lastIndexOf("```");
112+
if (firstBorder != -1 && lastBorder != -1 && lastBorder > firstBorder) {
113+
json = json.substring(firstBorder + 7, lastBorder);
114+
}
115+
return mapper.readValue(json, responseType);
66116
}
67-
return false;
68117
}
69118
}

0 commit comments

Comments
 (0)