
Commit c5d1b7b

Gemini support for new interface

1 parent 540b6a7

File tree

3 files changed: 143 additions, 47 deletions

  lib/completions/endpoints/cohere.rb
  lib/completions/endpoints/gemini.rb
  spec/lib/completions/endpoints/gemini_spec.rb

lib/completions/endpoints/cohere.rb

Lines changed: 2 additions & 6 deletions

@@ -77,12 +77,8 @@ def extract_completion_from(response_raw)
         end
       end

-      def has_tool?(_ignored)
-        @has_tool
-      end
-
-      def native_tool_support?
-        true
+      def xml_tools_enabled?
+        false
       end

       def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
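
The Cohere change swaps the affirmative native_tool_support? flag for its inverse, xml_tools_enabled?: an endpoint that speaks the provider's structured tool-call API returns false, and XML-formatted tool prompting applies only when it returns true. A minimal standalone sketch of that dispatch, assuming illustrative names (EndpointBase, prepare_tools) that are not part of the plugin:

# Illustrative only: EndpointBase and prepare_tools are assumed names;
# xml_tools_enabled? is the method this commit introduces.
class EndpointBase
  # Default: tools are described to the model as XML in the prompt.
  def xml_tools_enabled?
    true
  end
end

class CohereEndpoint < EndpointBase
  # Cohere returns tool calls as structured JSON, so XML tooling is off.
  def xml_tools_enabled?
    false
  end
end

def prepare_tools(endpoint)
  if endpoint.xml_tools_enabled?
    "inject <function_calls> XML instructions into the prompt"
  else
    "send tool schemas through the provider's native tools parameter"
  end
end

puts prepare_tools(CohereEndpoint.new)
# => send tool schemas through the provider's native tools parameter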

lib/completions/endpoints/gemini.rb

Lines changed: 68 additions & 26 deletions

@@ -111,7 +111,7 @@ def chunk_to_string(chunk)
         chunk.to_s
       end

-      class Decoder
+      class GeminiStreamingDecoder
         def initialize
           @buffer = +""
         end
@@ -151,44 +151,86 @@ def decode(str)
       end

       def decode(chunk)
-        @decoder ||= Decoder.new
-        @decoder.decode(chunk)
+        json = JSON.parse(chunk, symbolize_names: true)
+        idx = -1
+        json.dig(:candidates, 0, :content, :parts).map do |part|
+          if part[:functionCall]
+            idx += 1
+            ToolCall.new(
+              id: "tool_#{idx}",
+              name: part[:functionCall][:name],
+              parameters: part[:functionCall][:args],
+            )
+          else
+            part = part[:text]
+            if part != ""
+              part
+            else
+              nil
+            end
+          end
+        end
       end

-      def extract_prompt_for_tokenizer(prompt)
-        prompt.to_s
-      end
+      def decode_chunk(chunk)
+        @tool_index ||= -1

-      def has_tool?(_response_data)
-        @has_function_call
+        streaming_decoder.decode(chunk).map do |parsed|
+          update_usage(parsed)
+          parsed.dig(:candidates, 0, :content, :parts).map do |part|
+            if part[:text]
+              part = part[:text]
+              if part != ""
+                part
+              else
+                nil
+              end
+            elsif part[:functionCall]
+              @tool_index += 1
+              ToolCall.new(
+                id: "tool_#{@tool_index}",
+                name: part[:functionCall][:name],
+                parameters: part[:functionCall][:args],
+              )
+            end
+          end
+        end.flatten.compact
       end

-      def native_tool_support?
-        true
+      def update_usage(parsed)
+        usage = parsed.dig(:usageMetadata)
+        if usage
+          if prompt_token_count = usage[:promptTokenCount]
+            @prompt_token_count = prompt_token_count
+          end
+          if candidate_token_count = usage[:candidatesTokenCount]
+            @candidate_token_count = candidate_token_count
+          end
+        end
       end

-      def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
-        if @streaming_mode
-          return function_buffer if !partial
-        else
-          partial = payload
+      def final_log_update(log)
+        if @prompt_token_count
+          log.request_tokens = @prompt_token_count
         end

-        function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
+        if @candidate_token_count
+          log.response_tokens = @candidate_token_count
+        end
+      end

-        if partial[:args]
-          argument_fragments =
-            partial[:args].reduce(+"") do |memo, (arg_name, value)|
-              memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
-            end
-          argument_fragments << "\n"
+      def streaming_decoder
+        @decoder ||= GeminiStreamingDecoder.new
+      end

-          function_buffer.at("parameters").children =
-            Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
-        end
+      def extract_prompt_for_tokenizer(prompt)
+        prompt.to_s
+      end

-        function_buffer
+      def xml_tools_enabled?
+        false
       end
+
     end
   end
 end
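
The new decode/decode_chunk pair replaces the XML function buffer: each parsed Gemini response is mapped part by part, with text parts streaming through as plain strings, functionCall parts becoming ToolCall objects with sequential tool_N ids, and empty text parts dropped, while update_usage captures promptTokenCount and candidatesTokenCount for the audit log. A self-contained sketch of the part-mapping, where ToolCall is an illustrative Struct standing in for DiscourseAi::Completions::ToolCall:

# Illustrative stand-in for the plugin's ToolCall class.
ToolCall = Struct.new(:id, :name, :parameters, keyword_init: true)

# Mirrors the mapping in decode/decode_chunk: strings for text parts,
# sequentially numbered tool calls for functionCall parts, nils filtered out.
def parts_to_partials(parsed, tool_index: -1)
  parsed.dig(:candidates, 0, :content, :parts).filter_map do |part|
    if part[:functionCall]
      tool_index += 1
      ToolCall.new(
        id: "tool_#{tool_index}",
        name: part[:functionCall][:name],
        parameters: part[:functionCall][:args],
      )
    elsif part[:text] && part[:text] != ""
      part[:text]
    end
  end
end

parsed = {
  candidates: [
    {
      content: {
        parts: [{ text: "Hello" }, { functionCall: { name: "echo", args: { text: "hi" } } }],
        role: "model",
      },
    },
  ],
}

p parts_to_partials(parsed)
# => ["Hello", #<struct ToolCall id="tool_0", name="echo", parameters={:text=>"hi"}>]

The commit's version keeps a persistent @tool_index across chunks so ids stay stable over a whole streamed response; the sketch threads it as an argument instead.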

spec/lib/completions/endpoints/gemini_spec.rb

Lines changed: 73 additions & 15 deletions

@@ -195,19 +195,16 @@ def tool_response

       response = llm.generate(prompt, user: user)

-      expected = (<<~XML).strip
-        <function_calls>
-        <invoke>
-        <tool_name>echo</tool_name>
-        <parameters>
-        <text>&lt;S&gt;ydney</text>
-        </parameters>
-        <tool_id>tool_0</tool_id>
-        </invoke>
-        </function_calls>
-      XML
-
-      expect(response.strip).to eq(expected)
+      tool =
+        DiscourseAi::Completions::ToolCall.new(
+          id: "tool_0",
+          name: "echo",
+          parameters: {
+            text: "<S>ydney",
+          },
+        )
+
+      expect(response).to eq(tool)
     end

     it "Supports Vision API" do
@@ -265,6 +262,67 @@ def tool_response
       expect(JSON.parse(req_body)).to eq(expected_prompt)
     end

+    it "Can stream tool calls correctly" do
+      rows = [
+        {
+          candidates: [
+            {
+              content: {
+                parts: [{ functionCall: { name: "echo", args: { text: "sam<>wh!s" } } }],
+                role: "model",
+              },
+              safetyRatings: [
+                { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+                { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+                { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+                { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+              ],
+            },
+          ],
+          usageMetadata: {
+            promptTokenCount: 625,
+            totalTokenCount: 625,
+          },
+          modelVersion: "gemini-1.5-pro-002",
+        },
+        {
+          candidates: [{ content: { parts: [{ text: "" }], role: "model" }, finishReason: "STOP" }],
+          usageMetadata: {
+            promptTokenCount: 625,
+            candidatesTokenCount: 4,
+            totalTokenCount: 629,
+          },
+          modelVersion: "gemini-1.5-pro-002",
+        },
+      ]
+
+      payload = rows.map { |r| "data: #{r.to_json}\n\n" }.join
+
+      llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+      url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
+
+      prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool])
+
+      output = []
+
+      stub_request(:post, url).to_return(status: 200, body: payload)
+      llm.generate(prompt, user: user) { |partial| output << partial }
+
+      tool_call = DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "echo",
+        parameters: {
+          text: "sam<>wh!s",
+        },
+      )
+
+      expect(output).to eq([tool_call])
+
+      log = AiApiAuditLog.order(:id).last
+      expect(log.request_tokens).to eq(625)
+      expect(log.response_tokens).to eq(4)
+    end
+
     it "Can correctly handle streamed responses even if they are chunked badly" do
       data = +""
       data << "da|ta: |"
@@ -279,12 +337,12 @@ def tool_response
       llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
       url = "#{model.url}:streamGenerateContent?alt=sse&key=123"

-      output = +""
+      output = []
       gemini_mock.with_chunk_array_support do
         stub_request(:post, url).to_return(status: 200, body: split)
         llm.generate("Hello", user: user) { |partial| output << partial }
       end

-      expect(output).to eq("Hello World Sam")
+      expect(output.join).to eq("Hello World Sam")
     end
   end
