|
1 | 1 | package ee.carlrobert.codegpt.completions; |
2 | 2 |
|
| 3 | +import com.fasterxml.jackson.core.JsonProcessingException; |
| 4 | +import com.fasterxml.jackson.databind.node.ObjectNode; |
3 | 5 | import com.intellij.openapi.application.ApplicationManager; |
4 | 6 | import com.intellij.openapi.components.Service; |
5 | 7 | import ee.carlrobert.codegpt.completions.factory.CustomOpenAIRequest; |
|
12 | 14 | import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest; |
13 | 15 | import ee.carlrobert.llm.client.codegpt.request.InlineEditRequest; |
14 | 16 | import ee.carlrobert.llm.client.codegpt.request.chat.ChatCompletionRequest; |
| 17 | +import ee.carlrobert.llm.client.google.completion.ApiResponseError; |
15 | 18 | import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest; |
| 19 | +import ee.carlrobert.llm.client.google.completion.GoogleCompletionResponse; |
| 20 | +import ee.carlrobert.llm.client.google.completion.GoogleContentPart; |
| 21 | +import ee.carlrobert.llm.client.google.models.GoogleModel; |
16 | 22 | import ee.carlrobert.llm.client.openai.completion.ErrorDetails; |
17 | 23 | import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener; |
18 | 24 | import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener; |
|
21 | 27 | import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponseChoice; |
22 | 28 | import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponseChoiceDelta; |
23 | 29 | import ee.carlrobert.llm.completion.CompletionEventListener; |
| 30 | +import ee.carlrobert.llm.completion.CompletionEventSourceListener; |
24 | 31 | import ee.carlrobert.llm.completion.CompletionRequest; |
25 | 32 | import java.io.IOException; |
26 | 33 | import java.util.Collection; |
|
30 | 37 | import java.util.stream.Stream; |
31 | 38 | import okhttp3.Call; |
32 | 39 | import okhttp3.Callback; |
| 40 | +import okhttp3.HttpUrl; |
| 41 | +import okhttp3.MediaType; |
33 | 42 | import okhttp3.Request; |
34 | 43 | import okhttp3.RequestBody; |
35 | 44 | import okhttp3.Response; |
|
41 | 50 | @Service |
42 | 51 | public final class CompletionRequestService { |
43 | 52 |
|
| 53 | + private static final String GOOGLE_BASE_URL = |
| 54 | + "https://generativelanguage.googleapis.com"; |
| 55 | + private static final MediaType JSON_MEDIA_TYPE = MediaType.parse("application/json"); |
| 56 | + |
44 | 57 | private CompletionRequestService() { |
45 | 58 | } |
46 | 59 |
|
@@ -240,10 +253,12 @@ public EventSource getChatCompletionAsync( |
240 | 253 | eventListener); |
241 | 254 | } |
242 | 255 | if (request instanceof GoogleCompletionRequest completionRequest) { |
| 256 | + var model = ModelSelectionService.getInstance().getModelForFeature(featureType, null); |
| 257 | + if (model != null && GoogleModel.findByCode(model) == null) { |
| 258 | + return getGoogleNonEnumModelCompletionAsync(completionRequest, model, eventListener); |
| 259 | + } |
243 | 260 | return CompletionClientProvider.getGoogleClient().getChatCompletionAsync( |
244 | | - completionRequest, |
245 | | - ModelSelectionService.getInstance().getModelForFeature(featureType, null), |
246 | | - eventListener); |
| 261 | + completionRequest, model, eventListener); |
247 | 262 | } |
248 | 263 |
|
249 | 264 | throw new IllegalStateException("Unknown request type: " + request.getClass()); |
@@ -293,11 +308,14 @@ public String getChatCompletion(CompletionRequest request, ServiceType serviceTy |
293 | 308 | .getText(); |
294 | 309 | } |
295 | 310 | if (request instanceof GoogleCompletionRequest completionRequest) { |
| 311 | + var model = ApplicationManager.getApplication() |
| 312 | + .getService(ModelSelectionService.class) |
| 313 | + .getModelForFeature(featureType, null); |
| 314 | + if (model != null && GoogleModel.findByCode(model) == null) { |
| 315 | + return getGoogleNonEnumModelCompletion(completionRequest, model); |
| 316 | + } |
296 | 317 | return CompletionClientProvider.getGoogleClient().getChatCompletion( |
297 | | - completionRequest, |
298 | | - ApplicationManager.getApplication() |
299 | | - .getService(ModelSelectionService.class) |
300 | | - .getModelForFeature(featureType, null)) |
| 318 | + completionRequest, model) |
301 | 319 | .getCandidates().get(0) |
302 | 320 | .getContent().getParts().get(0) |
303 | 321 | .getText(); |
@@ -333,6 +351,117 @@ public static boolean isRequestAllowed(ServiceType serviceType) { |
333 | 351 | }; |
334 | 352 | } |
335 | 353 |
|
| 354 | + private EventSource getGoogleNonEnumModelCompletionAsync( |
| 355 | + GoogleCompletionRequest request, |
| 356 | + String model, |
| 357 | + CompletionEventListener<String> eventListener) { |
| 358 | + try { |
| 359 | + var httpRequest = buildGoogleNonEnumRequest(model, "streamGenerateContent", request, true); |
| 360 | + var httpClient = CompletionClientProvider.getDefaultClientBuilder().build(); |
| 361 | + return EventSources.createFactory(httpClient).newEventSource( |
| 362 | + httpRequest, createGoogleEventSourceListener(eventListener)); |
| 363 | + } catch (JsonProcessingException e) { |
| 364 | + throw new RuntimeException("Failed to serialize Google completion request", e); |
| 365 | + } |
| 366 | + } |
| 367 | + |
| 368 | + private String getGoogleNonEnumModelCompletion( |
| 369 | + GoogleCompletionRequest request, |
| 370 | + String model) { |
| 371 | + try { |
| 372 | + var httpRequest = buildGoogleNonEnumRequest(model, "generateContent", request, false); |
| 373 | + var httpClient = CompletionClientProvider.getDefaultClientBuilder().build(); |
| 374 | + try (var response = httpClient.newCall(httpRequest).execute()) { |
| 375 | + return DeserializationUtil.mapResponse(response, GoogleCompletionResponse.class) |
| 376 | + .getCandidates().get(0) |
| 377 | + .getContent().getParts().get(0) |
| 378 | + .getText(); |
| 379 | + } |
| 380 | + } catch (IOException e) { |
| 381 | + throw new RuntimeException("Failed to get Google completion", e); |
| 382 | + } |
| 383 | + } |
| 384 | + |
| 385 | + private Request buildGoogleNonEnumRequest( |
| 386 | + String model, String action, Object requestBody, boolean stream) |
| 387 | + throws JsonProcessingException { |
| 388 | + var apiKey = CredentialsStore.INSTANCE.getCredential(CredentialKey.GoogleApiKey.INSTANCE); |
| 389 | + var urlBuilder = HttpUrl.parse( |
| 390 | + GOOGLE_BASE_URL + "/v1beta/models/" + model + ":" + action).newBuilder(); |
| 391 | + if (apiKey != null && !apiKey.isEmpty()) { |
| 392 | + urlBuilder.addQueryParameter("key", apiKey); |
| 393 | + } |
| 394 | + if (stream) { |
| 395 | + urlBuilder.addQueryParameter("alt", "sse"); |
| 396 | + } |
| 397 | + |
| 398 | + var mapper = DeserializationUtil.OBJECT_MAPPER; |
| 399 | + var jsonNode = (ObjectNode) mapper.valueToTree(requestBody); |
| 400 | + |
| 401 | + // Inject thinkingConfig for models that support thinking (3.x+) |
| 402 | + var genConfig = jsonNode.has("generationConfig") |
| 403 | + ? (ObjectNode) jsonNode.get("generationConfig") |
| 404 | + : mapper.createObjectNode(); |
| 405 | + if (!genConfig.has("thinkingConfig")) { |
| 406 | + var thinkingConfig = mapper.createObjectNode(); |
| 407 | + thinkingConfig.put("thinkingLevel", "low"); |
| 408 | + genConfig.set("thinkingConfig", thinkingConfig); |
| 409 | + } |
| 410 | + if (!jsonNode.has("generationConfig")) { |
| 411 | + jsonNode.set("generationConfig", genConfig); |
| 412 | + } |
| 413 | + |
| 414 | + return new Request.Builder() |
| 415 | + .url(urlBuilder.build()) |
| 416 | + .header("Cache-Control", "no-cache") |
| 417 | + .header("Content-Type", "application/json") |
| 418 | + .header("Accept", stream ? "text/event-stream" : "text/json") |
| 419 | + .post(RequestBody.create(mapper.writeValueAsString(jsonNode), JSON_MEDIA_TYPE)) |
| 420 | + .build(); |
| 421 | + } |
| 422 | + |
| 423 | + private CompletionEventSourceListener<String> createGoogleEventSourceListener( |
| 424 | + CompletionEventListener<String> eventListener) { |
| 425 | + return new CompletionEventSourceListener<>(eventListener) { |
| 426 | + @Override |
| 427 | + protected String getMessage(String data) { |
| 428 | + try { |
| 429 | + var candidates = DeserializationUtil.OBJECT_MAPPER |
| 430 | + .readValue(data, GoogleCompletionResponse.class) |
| 431 | + .getCandidates(); |
| 432 | + return (candidates == null |
| 433 | + ? Stream.<GoogleCompletionResponse.Candidate>empty() |
| 434 | + : candidates.stream()) |
| 435 | + .filter(Objects::nonNull) |
| 436 | + .flatMap(candidate -> { |
| 437 | + if (candidate.getContent() != null |
| 438 | + && candidate.getContent().getParts() != null) { |
| 439 | + return candidate.getContent().getParts().stream(); |
| 440 | + } |
| 441 | + return Stream.empty(); |
| 442 | + }) |
| 443 | + .filter(Objects::nonNull) |
| 444 | + .filter(part -> part.getThought() == null || !part.getThought()) |
| 445 | + .findFirst() |
| 446 | + .map(GoogleContentPart::getText) |
| 447 | + .orElse(""); |
| 448 | + } catch (JsonProcessingException e) { |
| 449 | + // ignore |
| 450 | + } |
| 451 | + return ""; |
| 452 | + } |
| 453 | + |
| 454 | + @Override |
| 455 | + protected ErrorDetails getErrorDetails(String data) throws JsonProcessingException { |
| 456 | + var googleError = DeserializationUtil.OBJECT_MAPPER |
| 457 | + .readValue(data, ApiResponseError.class).getError(); |
| 458 | + return googleError == null ? null |
| 459 | + : new ErrorDetails(googleError.getMessage(), googleError.getStatus(), null, |
| 460 | + googleError.getCode()); |
| 461 | + } |
| 462 | + }; |
| 463 | + } |
| 464 | + |
336 | 465 | /** |
337 | 466 | * Content of the first choice. |
338 | 467 | * <ul> |
|
0 commit comments