Commit be55c9b

feat: adapter and debug window improvement
1 parent 6846482 commit be55c9b

26 files changed: 684 additions, 290 deletions

codecompanion-workspace.json

Lines changed: 24 additions & 0 deletions
@@ -81,6 +81,30 @@
           "path": "tests/helpers.lua"
         }
       ]
+    },
+    {
+      "name": "Adapters",
+      "system_prompt": "In the CodeCompanion plugin, adapters are used to connect to LLMs. The adapters contain various options for the LLM's endpoint alongside a defined schema for properties such as the model, temperature, top k, top p etc. The adapters also contain various handler functions which define how messages which are sent to the LLM should be formatted alongside how output from the LLM should be received and displayed in the chat buffer. The adapters are defined in the `adapters` directory.",
+      "opts": {
+        "remove_config_system_prompt": true
+      },
+      "vars": {
+        "base_dir": "lua/codecompanion"
+      },
+      "files": [
+        {
+          "description": "Each LLM has their own adapter. This allows for LLM settings to be generated from the schema table in an adapter before they're sent to the LLM via the http file.",
+          "path": "${base_dir}/adapters/init.lua"
+        },
+        {
+          "description": "Adapters are then passed to the http client which sends requests to LLMs via Curl:",
+          "path": "${base_dir}/http.lua"
+        },
+        {
+          "description": "Adapters must follow a schema. The validation and how schema values are extracted from the table schema is defined in:",
+          "path": "${base_dir}/schema.lua"
+        }
+      ]
     }
   ]
 }
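
The new "Adapters" group describes the adapter contract this commit reworks: an endpoint, a schema, and a set of handlers. For orientation, a minimal sketch of that shape, assembled only from field names that appear in the diffs below; the adapter name, URL, and values are made up for the example:

```lua
-- Illustrative adapter skeleton; the field set follows this commit's diffs,
-- everything else (name, endpoint, values) is hypothetical
return {
  name = "example",
  url = "https://api.example.com/v1/messages", -- hypothetical endpoint
  env = { api_key = "EXAMPLE_API_KEY" },
  opts = { stream = true },
  handlers = {
    -- Runs before the request is built; may derive parameters from opts
    setup = function(self)
      return true
    end,
    -- Maps raw LLM output onto what the chat buffer displays
    chat_output = function(self, data) end,
  },
  schema = {
    temperature = {
      mapping = "parameters",
      default = 0.7,
      validate = function(n)
        return n >= 0 and n <= 2, "Must be between 0 and 2"
      end,
    },
  },
}
```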

doc/extending/adapters.md

Lines changed: 4 additions & 1 deletion
@@ -18,6 +18,8 @@ Let's take a look at the interface of an adapter as per the `adapter.lua` file:
 ---@field env_replaced? table Replacement of environment variables with their actual values
 ---@field headers table The headers to pass to the request
 ---@field parameters table The parameters to pass to the request
+---@field body table Additional body parameters to pass to the request
+---@field chat_prompt string The system chat prompt to send to the LLM
 ---@field raw? table Any additional curl arguments to pass to the request
 ---@field opts? table Additional options for the adapter
 ---@field handlers table Functions which link the output from the request to CodeCompanion
@@ -448,4 +450,5 @@ temperature = {
 },
 ```
 
-You'll see we've specified a function call for the `condition` key. We're simply checking that the model name doesn't begin with `o1` as these models don't accept temperature as a parameter. You'll also see we've specified a function call for the `validate` key. We're simply checking that the value of the temperature is between 0 and 2
+You'll see we've specified a function call for the `condition` key. We're simply checking that the model name doesn't begin with `o1` as these models don't accept temperature as a parameter. You'll also see we've specified a function call for the `validate` key. We're simply checking that the value of the temperature is between 0 and 2.
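
The paragraph amended above describes a schema entry along these lines. The `o1` prefix check and the 0-to-2 range come straight from the text; the surrounding field layout and the exact `condition` signature are assumed from the schema entries elsewhere in this commit:

```lua
-- Sketch of the temperature schema entry the doc paragraph refers to
-- (condition's parameter is an assumption, not confirmed by this diff)
temperature = {
  mapping = "parameters",
  type = "number",
  optional = true,
  default = 1,
  -- Skip this parameter for o1 models, which reject temperature
  condition = function(schema)
    return not vim.startswith(schema.model.default, "o1")
  end,
  validate = function(n)
    return n >= 0 and n <= 2, "Must be between 0 and 2"
  end,
},
```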

lua/codecompanion/adapters/anthropic.lua

Lines changed: 51 additions & 22 deletions
@@ -18,6 +18,11 @@ return {
     text = true,
     vision = true,
   },
+  opts = {
+    stream = true,
+    cache_breakpoints = 4, -- Cache up to this many messages
+    cache_over = 300, -- Cache any message which has this many tokens or more
+  },
   url = "https://api.anthropic.com/v1/messages",
   env = {
     api_key = "ANTHROPIC_API_KEY",
@@ -28,15 +33,17 @@
     ["anthropic-version"] = "2023-06-01",
     ["anthropic-beta"] = "prompt-caching-2024-07-31",
   },
-  parameters = {
-    stream = true,
-  },
-  opts = {
-    stream = true, -- NOTE: Currently, CodeCompanion ONLY supports streaming with this adapter
-    cache_breakpoints = 4, -- Cache up to this many messages
-    cache_over = 300, -- Cache any message which has this many tokens or more
-  },
   handlers = {
+    ---@param self CodeCompanion.Adapter
+    ---@return boolean
+    setup = function(self)
+      if self.opts and self.opts.stream then
+        self.parameters.stream = true
+      end
+
+      return true
+    end,
+
     ---Set the parameters
     ---@param self CodeCompanion.Adapter
     ---@param params table
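
The new `setup` handler is the core of this change: `parameters.stream` is now derived from `opts.stream` at request time instead of being hard-coded, so the adapter can also run non-streamed. A sketch of how a user config could now opt out of streaming, assuming the `adapters.extend` helper CodeCompanion exposes for overrides:

```lua
-- Sketch: disable streaming for Anthropic via a config override
-- (assumes the require("codecompanion.adapters").extend helper)
require("codecompanion").setup({
  adapters = {
    anthropic = require("codecompanion.adapters").extend("anthropic", {
      opts = { stream = false }, -- setup() then leaves parameters.stream unset
    }),
  },
})
```
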
@@ -116,11 +123,15 @@
 
     ---Returns the number of tokens generated from the LLM
     ---@param self CodeCompanion.Adapter
-    ---@param data string The data from the LLM
+    ---@param data table The data from the LLM
     ---@return number|nil
     tokens = function(self, data)
      if data then
-        data = data:sub(6)
+        if self.opts.stream then
+          data = utils.clean_streamed_data(data)
+        else
+          data = data.body
+        end
        local ok, json = pcall(vim.json.decode, data)
 
        if ok then
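
The `tokens` handler, and `chat_output`/`inline_output` below, now branch on `self.opts.stream`: a streamed chunk arrives as an SSE string such as `data: {...}`, while a non-streamed response arrives as a table whose `body` field carries the raw JSON. `utils.clean_streamed_data` is not shown in this commit; presumably it strips the SSE framing, roughly like this:

```lua
-- Hypothetical sketch of clean_streamed_data; the real helper lives in
-- codecompanion.utils.adapters and is not part of this diff
local function clean_streamed_data(data)
  if type(data) ~= "string" then
    return data
  end
  -- Drop the "data: " prefix that server-sent events place before each JSON chunk
  return (data:gsub("^data:%s*", "", 1))
end
```
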
@@ -129,29 +140,36 @@
              + (json.message.usage.cache_creation_input_tokens or 0)
 
            output_tokens = json.message.usage.output_tokens or 0
-          end
-          if json.type == "message_delta" then
+          elseif json.type == "message_delta" then
            return (input_tokens + output_tokens + json.usage.output_tokens)
+          elseif json.type == "message" then
+            return (json.usage.input_tokens + json.usage.output_tokens)
          end
        end
      end
    end,
 
    ---Output the data from the API ready for insertion into the chat buffer
    ---@param self CodeCompanion.Adapter
-    ---@param data string The streamed JSON data from the API, also formatted by the format_data handler
+    ---@param data table The streamed JSON data from the API, also formatted by the format_data handler
    ---@return table|nil
    chat_output = function(self, data)
      local output = {}
 
-      -- Skip the event messages
-      if type(data) == "string" and string.sub(data, 1, 6) == "event:" then
-        return
+      if self.opts.stream then
+        if type(data) == "string" and string.sub(data, 1, 6) == "event:" then
+          return
+        end
      end
 
      if data and data ~= "" then
-        local data_mod = utils.clean_streamed_data(data)
-        local ok, json = pcall(vim.json.decode, data_mod, { luanil = { object = true } })
+        if self.opts.stream then
+          data = utils.clean_streamed_data(data)
+        else
+          data = data.body
+        end
+
+        local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } })
 
        if ok then
          if json.type == "message_start" then
@@ -160,6 +178,9 @@
          elseif json.type == "content_block_delta" then
            output.role = nil
            output.content = json.delta.text
+          elseif json.type == "message" then
+            output.role = json.role
+            output.content = json.content[1].text
          end
 
          return {
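
The new `json.type == "message"` branches cover the non-streamed case, where the whole completion arrives as one JSON document. For reference, a non-streamed Messages API response decodes to roughly this shape, which is why the handlers read `json.content[1].text` and sum `json.usage`:

```lua
-- Abridged shape of a non-streamed Anthropic response after vim.json.decode
-- (illustrative values)
local json = {
  type = "message",
  role = "assistant",
  content = { { type = "text", text = "Hello!" } }, -- json.content[1].text
  usage = { input_tokens = 12, output_tokens = 5 }, -- summed by tokens()
}
```
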
@@ -176,17 +197,25 @@
    ---@param context table Useful context about the buffer to inline to
    ---@return table|nil
    inline_output = function(self, data, context)
-      if type(data) == "string" and string.sub(data, 1, 6) == "event:" then
-        return
+      if self.opts.stream then
+        if type(data) == "string" and string.sub(data, 1, 6) == "event:" then
+          return
+        end
      end
 
      if data and data ~= "" then
-        data = data:sub(6)
+        if self.opts.stream then
+          data = utils.clean_streamed_data(data)
+        else
+          data = data.body
+        end
        local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } })
 
        if ok then
          if json.type == "content_block_delta" then
            return json.delta.text
+          elseif json.type == "message" then
+            return json.content[1].text
          end
        end
      end
@@ -257,7 +286,7 @@
      default = nil,
      desc = "Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses",
      validate = function(n)
-        return n >= 0 and n <= 500, "Must be between 0 and 500"
+        return n >= 0, "Must be greater than 0"
      end,
    },
    stop_sequences = {

lua/codecompanion/adapters/copilot.lua

Lines changed: 1 addition & 1 deletion
@@ -254,7 +254,7 @@ return {
      order = 4,
      mapping = "parameters",
      type = "integer",
-      default = 4096,
+      default = 15000,
      desc = "The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length.",
    },
    top_p = {

lua/codecompanion/adapters/deepseek.lua

Lines changed: 0 additions & 4 deletions
@@ -1,4 +1,3 @@
-local log = require("codecompanion.utils.log")
 local openai = require("codecompanion.adapters.openai")
 local utils = require("codecompanion.utils.adapters")
 
@@ -90,15 +89,12 @@ return {
        if delta.role then
          output.role = delta.role
        end
-
        if self.opts.can_reason and delta.reasoning_content then
          output.reasoning = delta.reasoning_content
        end
-
        if delta.content then
          output.content = (output.content or "") .. delta.content
        end
-
        return {
          status = "success",
          output = output,

lua/codecompanion/adapters/gemini.lua

Lines changed: 74 additions & 5 deletions
@@ -2,6 +2,7 @@
 ---https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/Streaming_REST.ipynb
 
 local log = require("codecompanion.utils.log")
+local utils = require("codecompanion.utils.adapters")
 
 ---@class Gemini.Adapter: CodeCompanion.Adapter
 return {
@@ -12,17 +13,25 @@ return {
    user = "user",
  },
  opts = {
-    stream = true, -- NOTE: Currently, CodeCompanion ONLY supports streaming with this adapter
+    stream = true,
  },
  features = {
    tokens = true,
    text = true,
    vision = true,
  },
-  url = "https://generativelanguage.googleapis.com/v1beta/models/${model}:streamGenerateContent?alt=sse&key=${api_key}",
+  url = "https://generativelanguage.googleapis.com/v1beta/models/${model}${stream}key=${api_key}",
  env = {
    api_key = "GEMINI_API_KEY",
    model = "schema.model.default",
+    stream = function(self)
+      local stream = ":generateContent?"
+      if self.opts.stream then
+        -- NOTE: With sse each stream chunk is a GenerateContentResponse object with a portion of the output text in candidates[0].content.parts[0].text
+        stream = ":streamGenerateContent?alt=sse&"
+      end
+      return stream
+    end,
  },
  headers = {
    ["Content-Type"] = "application/json",
@@ -93,7 +102,7 @@
    ---@return number|nil
    tokens = function(self, data)
      if data and data ~= "" then
-        data = data:sub(6)
+        data = utils.clean_streamed_data(data)
        local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } })
 
        if ok then
@@ -110,7 +119,7 @@
      local output = {}
 
      if data and data ~= "" then
-        data = data:sub(6)
+        data = utils.clean_streamed_data(data)
        local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } })
 
        if ok and json.candidates[1].content then
@@ -132,7 +141,7 @@
    ---@return table|nil
    inline_output = function(self, data, context)
      if data and data ~= "" then
-        data = data:sub(6)
+        data = utils.clean_streamed_data(data)
        local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } })
 
        if not ok then
@@ -176,5 +185,65 @@
        "gemini-1.0-pro",
      },
    },
+    maxOutputTokens = {
+      order = 2,
+      mapping = "body.generationConfig",
+      type = "integer",
+      optional = true,
+      default = nil,
+      desc = "The maximum number of tokens to include in a response candidate. Note: The default value varies by model",
+      validate = function(n)
+        return n > 0, "Must be greater than 0"
+      end,
+    },
+    temperature = {
+      order = 3,
+      mapping = "body.generationConfig",
+      type = "number",
+      optional = true,
+      default = nil,
+      desc = "Controls the randomness of the output.",
+      validate = function(n)
+        return n >= 0 and n <= 2, "Must be between 0 and 2"
+      end,
+    },
+    topP = {
+      order = 4,
+      mapping = "body.generationConfig",
+      type = "integer",
+      optional = true,
+      default = nil,
+      desc = "The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and Top-p (nucleus) sampling. Tokens are sorted based on their assigned probabilities so that only the most likely tokens are considered. Top-k sampling directly limits the maximum number of tokens to consider, while Nucleus sampling limits the number of tokens based on the cumulative probability.",
+      validate = function(n)
+        return n > 0, "Must be greater than 0"
+      end,
+    },
+    topK = {
+      order = 5,
+      mapping = "body.generationConfig",
+      type = "integer",
+      optional = true,
+      default = nil,
+      desc = "The maximum number of tokens to consider when sampling",
+      validate = function(n)
+        return n > 0, "Must be greater than 0"
+      end,
+    },
+    presencePenalty = {
+      order = 6,
+      mapping = "body.generationConfig",
+      type = "number",
+      optional = true,
+      default = nil,
+      desc = "Presence penalty applied to the next token's logprobs if the token has already been seen in the response",
+    },
+    frequencyPenalty = {
+      order = 7,
+      mapping = "body.generationConfig",
+      type = "number",
+      optional = true,
+      default = nil,
+      desc = "Frequency penalty applied to the next token's logprobs, multiplied by the number of times each token has been seen in the response so far.",
+    },
  },
 }
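
All six new schema entries use `mapping = "body.generationConfig"`, so user-set values are nested under `generationConfig` in the request body rather than sent as top-level parameters. A sketch of the resulting payload fragment, with illustrative values:

```lua
-- Request body fragment produced by the generationConfig mappings above
-- (values are examples only)
local body = {
  generationConfig = {
    maxOutputTokens = 1024,
    temperature = 0.7,
    topP = 0.95,
    topK = 40,
  },
}
```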
