Skip to content

Commit d2e9c1e

Browse files
committed
add gemini support
1 parent c582a5b commit d2e9c1e

File tree

6 files changed

+106
-5
lines changed

6 files changed

+106
-5
lines changed

lua/codecompanion/adapters/gemini.lua

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ return {
1010
},
1111
opts = {
1212
stream = true,
13+
tools = true,
1314
},
1415
features = {
1516
text = true,
@@ -35,11 +36,14 @@ return {
3536
form_parameters = function(self, params, messages)
3637
return openai.handlers.form_parameters(self, params, messages)
3738
end,
39+
form_tools = function(self, tools)
40+
return openai.handlers.form_tools(self, tools)
41+
end,
3842
form_messages = function(self, messages)
3943
return openai.handlers.form_messages(self, messages)
4044
end,
41-
chat_output = function(self, data)
42-
return openai.handlers.chat_output(self, data)
45+
chat_output = function(self, data, tools)
46+
return openai.handlers.chat_output(self, data, tools)
4347
end,
4448
inline_output = function(self, data, context)
4549
return openai.handlers.inline_output(self, data, context)

lua/codecompanion/adapters/openai.lua

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,9 @@ return {
148148
end
149149

150150
if self.opts.tools and delta.tool_calls and tools then
151-
for _, tool in ipairs(delta.tool_calls) do
151+
for i, tool in ipairs(delta.tool_calls) do
152152
if self.opts.stream then
153-
local index = tostring(tool.index)
153+
local index = tool.index and tostring(tool.index) or tostring(i)
154154
if not vim.tbl_contains(vim.tbl_keys(tools), index) then
155155
tools[index] = {
156156
name = tool["function"]["name"],
@@ -159,7 +159,8 @@ return {
159159
end
160160
tools[index]["arguments"] = (tools[index]["arguments"] or "") .. (tool["function"]["arguments"] or "")
161161
else
162-
tools[tool.id] = {
162+
local id = (tool.id and tool.id ~= "") and tostring(tool.id) or tostring(i)
163+
tools[id] = {
163164
name = tool["function"]["name"],
164165
arguments = vim.json.decode(tool["function"]["arguments"]),
165166
}

lua/codecompanion/strategies/chat/agents/tools/weather.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ return {
1111
}
1212
end,
1313
},
14+
system_prompt = "The weather tool must only be called once for each location",
1415
schema = {
1516
type = "function",
1617
["function"] = {
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
{
2+
"choices": [
3+
{
4+
"finish_reason": "tool_calls",
5+
"index": 0,
6+
"message": {
7+
"role": "assistant",
8+
"tool_calls": [
9+
{
10+
"function": {
11+
"arguments": "{\"location\":\"London, UK\",\"units\":\"celsius\"}",
12+
"name": "weather"
13+
},
14+
"id": "",
15+
"type": "function"
16+
},
17+
{
18+
"function": {
19+
"arguments": "{\"units\":\"celsius\",\"location\":\"Paris, France\"}",
20+
"name": "weather"
21+
},
22+
"id": "",
23+
"type": "function"
24+
}
25+
]
26+
}
27+
}
28+
],
29+
"created": 1743631193,
30+
"model": "gemini-2.0-flash",
31+
"object": "chat.completion",
32+
"usage": {
33+
"completion_tokens": 16,
34+
"prompt_tokens": 62,
35+
"total_tokens": 78
36+
}
37+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
data: {"choices":[{"delta":{"role":"assistant","tool_calls":[{"function":{"arguments":"{\"units\":\"celsius\",\"location\":\"London\"}","name":"weather"},"id":"","type":"function"},{"function":{"arguments":"{\"units\":\"celsius\",\"location\":\"Paris\"}","name":"weather"},"id":"","type":"function"}]},"finish_reason":"tool_calls","index":0}],"created":1743628522,"model":"gemini-2.5-pro-exp-03-25","object":"chat.completion.chunk","usage":{"completion_tokens":38,"prompt_tokens":504,"total_tokens":542}}

tests/adapters/test_gemini.lua

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ T["Gemini adapter"]["can form messages to be sent to the API"] = function()
4848
h.eq(output, adapter.handlers.form_messages(adapter, messages))
4949
end
5050

51+
T["Gemini adapter"]["it can form tools to be sent to the API"] = function()
52+
local weather = require("tests/strategies/chat/agents/tools/stubs/weather").schema
53+
local tools = { weather = { weather } }
54+
55+
h.eq({ tools = { weather } }, adapter.handlers.form_tools(adapter, tools))
56+
end
57+
5158
T["Gemini adapter"]["Streaming"] = new_set()
5259

5360
T["Gemini adapter"]["Streaming"]["can output streamed data into the chat buffer"] = function()
@@ -63,6 +70,27 @@ T["Gemini adapter"]["Streaming"]["can output streamed data into the chat buffer"
6370
h.expect_starts_with("Elegant, dynamic", output)
6471
end
6572

73+
T["Gemini adapter"]["Streaming"]["can process tools"] = function()
74+
local tools = {}
75+
local lines = vim.fn.readfile("tests/adapters/stubs/gemini_tools_streaming.txt")
76+
for _, line in ipairs(lines) do
77+
adapter.handlers.chat_output(adapter, line, tools)
78+
end
79+
80+
local tool_output = {
81+
["1"] = {
82+
arguments = '{"units":"celsius","location":"London"}',
83+
name = "weather",
84+
},
85+
["2"] = {
86+
arguments = '{"units":"celsius","location":"Paris"}',
87+
name = "weather",
88+
},
89+
}
90+
91+
h.eq(tool_output, tools)
92+
end
93+
6694
T["Gemini adapter"]["No Streaming"] = new_set({
6795
hooks = {
6896
pre_case = function()
@@ -85,6 +113,35 @@ T["Gemini adapter"]["No Streaming"]["can output for the chat buffer"] = function
85113
h.expect_starts_with("Elegant, dynamic.", adapter.handlers.chat_output(adapter, json).output.content)
86114
end
87115

116+
T["Gemini adapter"]["No Streaming"]["can process tools"] = function()
117+
local data = vim.fn.readfile("tests/adapters/stubs/gemini_tools_no_streaming.txt")
118+
data = table.concat(data, "\n")
119+
120+
local tools = {}
121+
122+
-- Match the format of the actual request
123+
local json = { body = data }
124+
adapter.handlers.chat_output(adapter, json, tools)
125+
126+
local tool_output = {
127+
["1"] = {
128+
arguments = {
129+
location = "London, UK",
130+
units = "celsius",
131+
},
132+
name = "weather",
133+
},
134+
["2"] = {
135+
arguments = {
136+
location = "Paris, France",
137+
units = "celsius",
138+
},
139+
name = "weather",
140+
},
141+
}
142+
h.eq(tool_output, tools)
143+
end
144+
88145
T["Gemini adapter"]["No Streaming"]["can output for the inline assistant"] = function()
89146
local data = vim.fn.readfile("tests/adapters/stubs/gemini_no_streaming.txt")
90147
data = table.concat(data, "\n")

0 commit comments

Comments
 (0)