Skip to content

Commit 947d8a2

Browse files
committed
custom openai responses API support
1 parent b501ffa commit 947d8a2

File tree

9 files changed

+420
-28
lines changed

9 files changed

+420
-28
lines changed

src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ public EventSource getCustomOpenAIChatCompletionAsync(
7070
new OpenAIChatCompletionEventSourceListener(eventListener));
7171
}
7272

73+
public EventSource getCustomOpenAIResponsesApiAsync(
74+
Request customRequest,
75+
CompletionEventListener<String> eventListener) {
76+
var httpClient = CompletionClientProvider.getDefaultClientBuilder().build();
77+
return EventSources.createFactory(httpClient).newEventSource(
78+
customRequest,
79+
new ResponsesApiEventSourceListener(eventListener));
80+
}
81+
7382
private EventSource getCustomOpenAINonStreamingChatCompletionAsync(
7483
Request customRequest,
7584
CompletionEventListener<String> eventListener) {
@@ -220,6 +229,9 @@ public EventSource getChatCompletionAsync(
220229
.getChatCompletionAsync(completionRequest, eventListener);
221230
}
222231
if (request instanceof CustomOpenAIRequest completionRequest) {
232+
if (completionRequest.isResponsesApi()) {
233+
return getCustomOpenAIResponsesApiAsync(completionRequest.getRequest(), eventListener);
234+
}
223235
return getCustomOpenAIChatCompletionAsync(completionRequest.getRequest(), eventListener);
224236
}
225237
if (request instanceof ClaudeCompletionRequest completionRequest) {
@@ -263,6 +275,9 @@ public String getChatCompletion(CompletionRequest request, ServiceType serviceTy
263275
if (request instanceof CustomOpenAIRequest completionRequest) {
264276
var httpClient = CompletionClientProvider.getDefaultClientBuilder().build();
265277
try (var response = httpClient.newCall(completionRequest.getRequest()).execute()) {
278+
if (completionRequest.isResponsesApi()) {
279+
return ResponsesApiResponseParser.extractContent(response);
280+
}
266281
return DeserializationUtil.mapResponse(response, OpenAIChatCompletionResponse.class)
267282
.getChoices().get(0)
268283
.getMessage()
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
package ee.carlrobert.codegpt.completions
2+
3+
import com.fasterxml.jackson.databind.ObjectMapper
4+
import com.intellij.openapi.diagnostic.Logger
5+
import ee.carlrobert.llm.client.openai.completion.ErrorDetails
6+
import ee.carlrobert.llm.client.openai.completion.response.ToolCall
7+
import ee.carlrobert.llm.client.openai.completion.response.ToolFunctionResponse
8+
import ee.carlrobert.llm.completion.CompletionEventListener
9+
import okhttp3.Response
10+
import okhttp3.internal.http2.StreamResetException
11+
import okhttp3.sse.EventSource
12+
import okhttp3.sse.EventSourceListener
13+
14+
class ResponsesApiEventSourceListener(
15+
private val listener: CompletionEventListener<String>
16+
) : EventSourceListener() {
17+
18+
companion object {
19+
private val LOG = Logger.getInstance(ResponsesApiEventSourceListener::class.java)
20+
private val OBJECT_MAPPER = ObjectMapper()
21+
}
22+
23+
private val messageBuilder = StringBuilder()
24+
25+
private data class FunctionCallState(
26+
var callId: String? = null,
27+
var name: String? = null,
28+
val arguments: StringBuilder = StringBuilder(),
29+
var index: Int = 0
30+
)
31+
32+
private val activeFunctionCalls = mutableMapOf<Int, FunctionCallState>()
33+
private var nextOutputIndex = 0
34+
35+
override fun onOpen(eventSource: EventSource, response: Response) {
36+
listener.onOpen()
37+
}
38+
39+
override fun onEvent(
40+
eventSource: EventSource,
41+
id: String?,
42+
type: String?,
43+
data: String
44+
) {
45+
try {
46+
when (type) {
47+
"response.output_text.delta" -> handleTextDelta(data, eventSource)
48+
"response.output_item.added" -> handleOutputItemAdded(data)
49+
"response.function_call_arguments.delta" -> handleFunctionCallArgsDelta(data)
50+
"response.output_item.done" -> handleOutputItemDone(data)
51+
"error" -> handleError(data)
52+
"response.created",
53+
"response.in_progress",
54+
"response.output_text.done",
55+
"response.content_part.added",
56+
"response.content_part.done",
57+
"response.completed" -> {
58+
// Informational events - no action needed
59+
}
60+
61+
else -> {
62+
LOG.debug("Unhandled Responses API event type: $type")
63+
listener.onEvent(data)
64+
}
65+
}
66+
} catch (e: Exception) {
67+
LOG.error("Error processing Responses API event type=$type", e)
68+
}
69+
}
70+
71+
override fun onClosed(eventSource: EventSource) {
72+
listener.onComplete(messageBuilder)
73+
}
74+
75+
override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) {
76+
if (t is StreamResetException
77+
|| (t is java.net.SocketException && t.message == "Socket closed")
78+
|| (t is java.io.IOException && t.message.equals("canceled", ignoreCase = true))
79+
) {
80+
listener.onCancelled(messageBuilder)
81+
return
82+
}
83+
84+
val errorDetails = try {
85+
if (response?.body != null) {
86+
val jsonBody = response.body!!.string()
87+
val node = OBJECT_MAPPER.readTree(jsonBody)
88+
val errorNode = node.get("error")
89+
if (errorNode != null) {
90+
ErrorDetails(
91+
errorNode.get("message")?.asText() ?: "Unknown error",
92+
errorNode.get("type")?.asText(),
93+
null,
94+
errorNode.get("code")?.asText()
95+
)
96+
} else {
97+
ErrorDetails("Unknown error. Code: ${response.code}, Body: $jsonBody")
98+
}
99+
} else {
100+
ErrorDetails(t?.message ?: "Unknown error")
101+
}
102+
} catch (_: Exception) {
103+
ErrorDetails(t?.message ?: "Unknown error")
104+
}
105+
106+
listener.onError(errorDetails, t ?: RuntimeException(errorDetails.message))
107+
}
108+
109+
private fun handleTextDelta(data: String, eventSource: EventSource) {
110+
val node = OBJECT_MAPPER.readTree(data)
111+
val delta = node.get("delta")?.asText() ?: return
112+
if (delta.isNotEmpty()) {
113+
messageBuilder.append(delta)
114+
listener.onMessage(delta, eventSource)
115+
}
116+
}
117+
118+
private fun handleOutputItemAdded(data: String) {
119+
val node = OBJECT_MAPPER.readTree(data)
120+
val item = node.get("item") ?: return
121+
val itemType = item.get("type")?.asText() ?: return
122+
123+
if (itemType == "function_call") {
124+
val idx = item.get("output_index")?.asInt() ?: nextOutputIndex++
125+
val callId = item.get("call_id")?.asText()
126+
val name = item.get("name")?.asText()
127+
activeFunctionCalls[idx] = FunctionCallState(
128+
callId = callId,
129+
name = name,
130+
index = idx
131+
)
132+
}
133+
}
134+
135+
private fun handleFunctionCallArgsDelta(data: String) {
136+
val node = OBJECT_MAPPER.readTree(data)
137+
val delta = node.get("delta")?.asText() ?: return
138+
val idx = node.get("output_index")?.asInt() ?: return
139+
activeFunctionCalls[idx]?.arguments?.append(delta)
140+
}
141+
142+
private fun handleOutputItemDone(data: String) {
143+
val node = OBJECT_MAPPER.readTree(data)
144+
val item = node.get("item") ?: return
145+
val itemType = item.get("type")?.asText() ?: return
146+
147+
if (itemType == "function_call") {
148+
val callId = item.get("call_id")?.asText() ?: return
149+
val name = item.get("name")?.asText() ?: return
150+
val arguments = item.get("arguments")?.asText() ?: ""
151+
val idx = item.get("output_index")?.asInt()
152+
153+
if (idx != null) activeFunctionCalls.remove(idx)
154+
155+
val toolCall = ToolCall(
156+
idx,
157+
callId,
158+
"function",
159+
ToolFunctionResponse(name, arguments)
160+
)
161+
listener.onToolCall(toolCall)
162+
}
163+
}
164+
165+
private fun handleError(data: String) {
166+
val node = OBJECT_MAPPER.readTree(data)
167+
val message =
168+
node.get("message")?.asText() ?: node.get("error")?.asText() ?: "Unknown error"
169+
val code = node.get("code")?.asText()
170+
listener.onError(
171+
ErrorDetails(message, null, null, code),
172+
RuntimeException(message)
173+
)
174+
}
175+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package ee.carlrobert.codegpt.completions
2+
3+
import com.fasterxml.jackson.databind.ObjectMapper
4+
import okhttp3.Response
5+
import java.io.IOException
6+
7+
object ResponsesApiResponseParser {
8+
9+
private val OBJECT_MAPPER = ObjectMapper()
10+
11+
@JvmStatic
12+
@Throws(IOException::class)
13+
fun extractContent(response: Response): String {
14+
val body = response.body?.string()
15+
?: throw IOException("Empty response body")
16+
val root = OBJECT_MAPPER.readTree(body)
17+
val output = root.get("output")
18+
?: throw IOException("Missing 'output' field in Responses API response")
19+
20+
for (item in output) {
21+
if (item.get("type")?.asText() == "message") {
22+
val content = item.get("content") ?: continue
23+
for (part in content) {
24+
if (part.get("type")?.asText() == "output_text") {
25+
return part.get("text")?.asText() ?: ""
26+
}
27+
}
28+
}
29+
}
30+
31+
throw IOException("No text content found in Responses API response")
32+
}
33+
}

0 commit comments

Comments
 (0)