Skip to content

Commit f5a2a16

Browse files
committed
google/gemini provider cleanup. replaced 3.0 pro preview with 3.1 pro preview. fixed issues with 2.5 models
1 parent 6d1bba3 commit f5a2a16

File tree

3 files changed

+185
-13
lines changed

3 files changed

+185
-13
lines changed

src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java

Lines changed: 136 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package ee.carlrobert.codegpt.completions;
22

3+
import com.fasterxml.jackson.core.JsonProcessingException;
4+
import com.fasterxml.jackson.databind.node.ObjectNode;
35
import com.intellij.openapi.application.ApplicationManager;
46
import com.intellij.openapi.components.Service;
57
import ee.carlrobert.codegpt.completions.factory.CustomOpenAIRequest;
@@ -12,7 +14,11 @@
1214
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
1315
import ee.carlrobert.llm.client.codegpt.request.InlineEditRequest;
1416
import ee.carlrobert.llm.client.codegpt.request.chat.ChatCompletionRequest;
17+
import ee.carlrobert.llm.client.google.completion.ApiResponseError;
1518
import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest;
19+
import ee.carlrobert.llm.client.google.completion.GoogleCompletionResponse;
20+
import ee.carlrobert.llm.client.google.completion.GoogleContentPart;
21+
import ee.carlrobert.llm.client.google.models.GoogleModel;
1622
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
1723
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
1824
import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener;
@@ -21,6 +27,7 @@
2127
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponseChoice;
2228
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponseChoiceDelta;
2329
import ee.carlrobert.llm.completion.CompletionEventListener;
30+
import ee.carlrobert.llm.completion.CompletionEventSourceListener;
2431
import ee.carlrobert.llm.completion.CompletionRequest;
2532
import java.io.IOException;
2633
import java.util.Collection;
@@ -30,6 +37,8 @@
3037
import java.util.stream.Stream;
3138
import okhttp3.Call;
3239
import okhttp3.Callback;
40+
import okhttp3.HttpUrl;
41+
import okhttp3.MediaType;
3342
import okhttp3.Request;
3443
import okhttp3.RequestBody;
3544
import okhttp3.Response;
@@ -41,6 +50,10 @@
4150
@Service
4251
public final class CompletionRequestService {
4352

53+
private static final String GOOGLE_BASE_URL =
54+
"https://generativelanguage.googleapis.com";
55+
private static final MediaType JSON_MEDIA_TYPE = MediaType.parse("application/json");
56+
4457
private CompletionRequestService() {
4558
}
4659

@@ -240,10 +253,12 @@ public EventSource getChatCompletionAsync(
240253
eventListener);
241254
}
242255
if (request instanceof GoogleCompletionRequest completionRequest) {
256+
var model = ModelSelectionService.getInstance().getModelForFeature(featureType, null);
257+
if (model != null && GoogleModel.findByCode(model) == null) {
258+
return getGoogleNonEnumModelCompletionAsync(completionRequest, model, eventListener);
259+
}
243260
return CompletionClientProvider.getGoogleClient().getChatCompletionAsync(
244-
completionRequest,
245-
ModelSelectionService.getInstance().getModelForFeature(featureType, null),
246-
eventListener);
261+
completionRequest, model, eventListener);
247262
}
248263

249264
throw new IllegalStateException("Unknown request type: " + request.getClass());
@@ -293,11 +308,14 @@ public String getChatCompletion(CompletionRequest request, ServiceType serviceTy
293308
.getText();
294309
}
295310
if (request instanceof GoogleCompletionRequest completionRequest) {
311+
var model = ApplicationManager.getApplication()
312+
.getService(ModelSelectionService.class)
313+
.getModelForFeature(featureType, null);
314+
if (model != null && GoogleModel.findByCode(model) == null) {
315+
return getGoogleNonEnumModelCompletion(completionRequest, model);
316+
}
296317
return CompletionClientProvider.getGoogleClient().getChatCompletion(
297-
completionRequest,
298-
ApplicationManager.getApplication()
299-
.getService(ModelSelectionService.class)
300-
.getModelForFeature(featureType, null))
318+
completionRequest, model)
301319
.getCandidates().get(0)
302320
.getContent().getParts().get(0)
303321
.getText();
@@ -333,6 +351,117 @@ public static boolean isRequestAllowed(ServiceType serviceType) {
333351
};
334352
}
335353

354+
private EventSource getGoogleNonEnumModelCompletionAsync(
355+
GoogleCompletionRequest request,
356+
String model,
357+
CompletionEventListener<String> eventListener) {
358+
try {
359+
var httpRequest = buildGoogleNonEnumRequest(model, "streamGenerateContent", request, true);
360+
var httpClient = CompletionClientProvider.getDefaultClientBuilder().build();
361+
return EventSources.createFactory(httpClient).newEventSource(
362+
httpRequest, createGoogleEventSourceListener(eventListener));
363+
} catch (JsonProcessingException e) {
364+
throw new RuntimeException("Failed to serialize Google completion request", e);
365+
}
366+
}
367+
368+
private String getGoogleNonEnumModelCompletion(
369+
GoogleCompletionRequest request,
370+
String model) {
371+
try {
372+
var httpRequest = buildGoogleNonEnumRequest(model, "generateContent", request, false);
373+
var httpClient = CompletionClientProvider.getDefaultClientBuilder().build();
374+
try (var response = httpClient.newCall(httpRequest).execute()) {
375+
return DeserializationUtil.mapResponse(response, GoogleCompletionResponse.class)
376+
.getCandidates().get(0)
377+
.getContent().getParts().get(0)
378+
.getText();
379+
}
380+
} catch (IOException e) {
381+
throw new RuntimeException("Failed to get Google completion", e);
382+
}
383+
}
384+
385+
private Request buildGoogleNonEnumRequest(
386+
String model, String action, Object requestBody, boolean stream)
387+
throws JsonProcessingException {
388+
var apiKey = CredentialsStore.INSTANCE.getCredential(CredentialKey.GoogleApiKey.INSTANCE);
389+
var urlBuilder = HttpUrl.parse(
390+
GOOGLE_BASE_URL + "/v1beta/models/" + model + ":" + action).newBuilder();
391+
if (apiKey != null && !apiKey.isEmpty()) {
392+
urlBuilder.addQueryParameter("key", apiKey);
393+
}
394+
if (stream) {
395+
urlBuilder.addQueryParameter("alt", "sse");
396+
}
397+
398+
var mapper = DeserializationUtil.OBJECT_MAPPER;
399+
var jsonNode = (ObjectNode) mapper.valueToTree(requestBody);
400+
401+
// Inject thinkingConfig for models that support thinking (3.x+)
402+
var genConfig = jsonNode.has("generationConfig")
403+
? (ObjectNode) jsonNode.get("generationConfig")
404+
: mapper.createObjectNode();
405+
if (!genConfig.has("thinkingConfig")) {
406+
var thinkingConfig = mapper.createObjectNode();
407+
thinkingConfig.put("thinkingLevel", "low");
408+
genConfig.set("thinkingConfig", thinkingConfig);
409+
}
410+
if (!jsonNode.has("generationConfig")) {
411+
jsonNode.set("generationConfig", genConfig);
412+
}
413+
414+
return new Request.Builder()
415+
.url(urlBuilder.build())
416+
.header("Cache-Control", "no-cache")
417+
.header("Content-Type", "application/json")
418+
.header("Accept", stream ? "text/event-stream" : "text/json")
419+
.post(RequestBody.create(mapper.writeValueAsString(jsonNode), JSON_MEDIA_TYPE))
420+
.build();
421+
}
422+
423+
private CompletionEventSourceListener<String> createGoogleEventSourceListener(
424+
CompletionEventListener<String> eventListener) {
425+
return new CompletionEventSourceListener<>(eventListener) {
426+
@Override
427+
protected String getMessage(String data) {
428+
try {
429+
var candidates = DeserializationUtil.OBJECT_MAPPER
430+
.readValue(data, GoogleCompletionResponse.class)
431+
.getCandidates();
432+
return (candidates == null
433+
? Stream.<GoogleCompletionResponse.Candidate>empty()
434+
: candidates.stream())
435+
.filter(Objects::nonNull)
436+
.flatMap(candidate -> {
437+
if (candidate.getContent() != null
438+
&& candidate.getContent().getParts() != null) {
439+
return candidate.getContent().getParts().stream();
440+
}
441+
return Stream.empty();
442+
})
443+
.filter(Objects::nonNull)
444+
.filter(part -> part.getThought() == null || !part.getThought())
445+
.findFirst()
446+
.map(GoogleContentPart::getText)
447+
.orElse("");
448+
} catch (JsonProcessingException e) {
449+
// ignore
450+
}
451+
return "";
452+
}
453+
454+
@Override
455+
protected ErrorDetails getErrorDetails(String data) throws JsonProcessingException {
456+
var googleError = DeserializationUtil.OBJECT_MAPPER
457+
.readValue(data, ApiResponseError.class).getError();
458+
return googleError == null ? null
459+
: new ErrorDetails(googleError.getMessage(), googleError.getStatus(), null,
460+
googleError.getCode());
461+
}
462+
};
463+
}
464+
336465
/**
337466
* Content of the first choice.
338467
* <ul>

src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,30 @@ class GoogleRequestFactory : BaseRequestFactory() {
2727
override fun createChatRequest(params: ChatCompletionParameters): GoogleCompletionRequest {
2828
val configuration = service<ConfigurationSettings>().state
2929
val selectedModel = ModelSelectionService.getInstance().getModelForFeature(FeatureType.CHAT)
30-
val messages = buildGoogleMessages(selectedModel, params)
30+
val systemInstruction = buildSystemInstruction(params)
31+
32+
val messages = if (!systemInstruction.isNullOrBlank()) {
33+
listOf(
34+
GoogleCompletionContent("user", listOf(systemInstruction)),
35+
GoogleCompletionContent("model", listOf("Understood."))
36+
) + buildGoogleMessages(selectedModel, params)
37+
} else {
38+
buildGoogleMessages(selectedModel, params)
39+
}
40+
3141
return GoogleCompletionRequest.Builder(messages)
3242
.generationConfig(
3343
GoogleGenerationConfig.Builder()
3444
.maxOutputTokens(configuration.maxTokens)
3545
.temperature(configuration.temperature.toDouble()).build()
3646
)
37-
.systemInstruction(buildSystemInstruction(params))
3847
.build()
3948
}
4049

50+
private fun isNonEnumModel(model: String?): Boolean {
51+
return model != null && GoogleModel.findByCode(model) == null
52+
}
53+
4154
override fun createBasicCompletionRequest(
4255
systemPrompt: String,
4356
userPrompt: String,
@@ -104,6 +117,8 @@ class GoogleRequestFactory : BaseRequestFactory() {
104117
break
105118
}
106119

120+
if (prevMessage.response.isNullOrBlank()) continue
121+
107122
prevMessage.imageFilePath?.takeIf { it.isNotEmpty() }?.let { imagePath ->
108123
try {
109124
val imageData = Files.readAllBytes(Path.of(imagePath))

src/main/kotlin/ee/carlrobert/codegpt/settings/models/ModelRegistry.kt

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import ai.koog.prompt.executor.clients.anthropic.AnthropicModels.Opus_4_5
55
import ai.koog.prompt.executor.clients.anthropic.AnthropicModels.Sonnet_4_5
66
import ai.koog.prompt.executor.clients.google.GoogleModels.Gemini2_5Flash
77
import ai.koog.prompt.executor.clients.google.GoogleModels.Gemini2_5Pro
8-
import ai.koog.prompt.executor.clients.google.GoogleModels.Gemini3_Pro_Preview
98
import ai.koog.prompt.executor.clients.mistralai.MistralAIModels.Chat.DevstralMedium
109
import ai.koog.prompt.executor.clients.openai.OpenAIModels.Chat.GPT4_1
1110
import ai.koog.prompt.executor.clients.openai.OpenAIModels.Chat.GPT4_1Mini
@@ -458,7 +457,7 @@ class ModelRegistry {
458457
)
459458
},
460459
ServiceType.GOOGLE to listOf(
461-
LLMModelWrapper(Gemini3_Pro_Preview, name = "Gemini 3 Pro Preview"),
460+
LLMModelWrapper(Gemini3_1_Pro_Preview, name = "Gemini 3.1 Pro Preview"),
462461
LLMModelWrapper(Gemini3_Flash_Preview, name = "Gemini 3 Flash Preview"),
463462
LLMModelWrapper(Gemini2_5Pro, name = "Gemini 2.5 Pro"),
464463
LLMModelWrapper(Gemini2_5Flash, name = "Gemini 2.5 Flash")
@@ -564,8 +563,8 @@ class ModelRegistry {
564563
return listOf(
565564
ModelSelection(
566565
ServiceType.GOOGLE,
567-
Gemini3_Pro_Preview.id,
568-
"Gemini 3 Pro Preview",
566+
Gemini3_1_Pro_Preview.id,
567+
"Gemini 3.1 Pro Preview",
569568
Icons.Google
570569
),
571570
ModelSelection(
@@ -831,6 +830,16 @@ class ModelRegistry {
831830

832831
private fun getGoogleModels(): List<ModelSelection> {
833832
return listOf(
833+
ModelSelection(
834+
ServiceType.GOOGLE,
835+
Gemini3_1_Pro_Preview.id,
836+
"Gemini 3.1 Pro Preview"
837+
),
838+
ModelSelection(
839+
ServiceType.GOOGLE,
840+
Gemini3_Flash_Preview.id,
841+
"Gemini 3 Flash Preview"
842+
),
834843
ModelSelection(
835844
ServiceType.GOOGLE,
836845
GoogleModel.GEMINI_2_5_PRO_PREVIEW.code,
@@ -996,6 +1005,25 @@ class ModelRegistry {
9961005
}
9971006
}
9981007

1008+
public val Gemini3_1_Pro_Preview: LLModel = LLModel(
1009+
provider = LLMProvider.Google,
1010+
id = "gemini-3.1-pro-preview",
1011+
capabilities = listOf(
1012+
LLMCapability.Temperature,
1013+
LLMCapability.Completion,
1014+
LLMCapability.MultipleChoices,
1015+
LLMCapability.Vision.Image,
1016+
LLMCapability.Vision.Video,
1017+
LLMCapability.Audio,
1018+
LLMCapability.Tools,
1019+
LLMCapability.ToolChoice,
1020+
LLMCapability.Schema.JSON.Basic,
1021+
LLMCapability.Schema.JSON.Standard,
1022+
),
1023+
contextLength = 1_048_576,
1024+
maxOutputTokens = 65_536,
1025+
)
1026+
9991027
public val Gemini3_Flash_Preview: LLModel = LLModel(
10001028
provider = LLMProvider.Google,
10011029
id = "gemini-3-flash-preview",

0 commit comments

Comments
 (0)