Commit 96570e6

refactor(adapter): DeepSeek is more aligned with OpenAI (#707)

1 parent 792a297 commit 96570e6

File tree

1 file changed: +16 −105 lines changed

lua/codecompanion/adapters/deepseek.lua

Lines changed: 16 additions & 105 deletions
@@ -1,4 +1,5 @@
 local log = require("codecompanion.utils.log")
+local openai = require("codecompanion.adapters.openai")
 
 ---Prepare data to be parsed as JSON
 ---@param data string | { body: string }
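For context, `prepare_data_for_json` is the local helper named just below this hunk; its body is unchanged by the commit and not shown in the diff. A plausible minimal sketch of such a helper, assuming the usual SSE framing of streamed chunks (this is an illustration, not the file's actual body):

```lua
-- Hypothetical sketch only; the real body of prepare_data_for_json is not in this diff.
---@param data string | { body: string }
---@return string
local function prepare_data_for_json(data)
  if type(data) == "table" then
    data = data.body -- assumed: non-streamed responses arrive as { body = "..." }
  end
  -- Streamed chunks are typically framed as "data: {...}"; strip the prefix before decoding.
  return (data:gsub("^data: ", ""))
end
```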
@@ -51,13 +52,12 @@ return {
       return true
     end,
 
-    ---Set the parameters
-    ---@param self CodeCompanion.Adapter
-    ---@param params table
-    ---@param messages table
-    ---@return table
+    --- Use the OpenAI adapter for the bulk of the work
+    tokens = function(self, data)
+      return openai.handlers.tokens(self, data)
+    end,
     form_parameters = function(self, params, messages)
-      return params
+      return openai.handlers.form_parameters(self, params, messages)
     end,
 
     ---Set the format of the role and content for the messages from the chat buffer
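The pattern introduced here is plain delegation: each DeepSeek handler forwards to the matching `openai.handlers` function, passing `self` through so the OpenAI code still sees DeepSeek's own options (for example, whether streaming is enabled). Sketched in isolation, with an abbreviated adapter table for illustration:

```lua
-- Minimal sketch of the delegation pattern used throughout this commit.
local openai = require("codecompanion.adapters.openai")

local adapter = {
  name = "deepseek", -- abbreviated; the real adapter defines far more fields
  opts = { stream = true },
  handlers = {
    -- `self` is the DeepSeek adapter, so openai's handler reads
    -- self.opts.stream and behaves exactly as it would for OpenAI.
    tokens = function(self, data)
      return openai.handlers.tokens(self, data)
    end,
  },
}
```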
@@ -80,7 +80,7 @@ return {
         else
           table.insert(processed, {
             role = msg.role,
-            content = msg.content
+            content = msg.content,
           })
         end
       end
@@ -109,97 +109,14 @@ return {
       return { messages = processed }
     end,
 
-
-    ---Returns the number of tokens generated from the LLM
-    ---@param self CodeCompanion.Adapter
-    ---@param data table The data from the LLM
-    ---@return number|nil
-    tokens = function(self, data)
-      if data and data ~= "" then
-        local data_mod = prepare_data_for_json(data)
-        local ok, json = pcall(vim.json.decode, data_mod, { luanil = { object = true } })
-
-        if ok then
-          if json.usage then
-            local tokens = json.usage.total_tokens
-            log:trace("Tokens: %s", tokens)
-            return tokens
-          end
-        end
-      end
-    end,
-
-    ---Output the data from the API ready for insertion into the chat buffer
-    ---@param self CodeCompanion.Adapter
-    ---@param data table The streamed JSON data from the API, also formatted by the format_data handler
-    ---@return table|nil [status: string, output: table]
     chat_output = function(self, data)
-      local output = {}
-
-      if data and data ~= "" then
-        local data_mod = prepare_data_for_json(data)
-        local ok, json = pcall(vim.json.decode, data_mod, { luanil = { object = true } })
-
-        if ok and json.choices and #json.choices > 0 then
-          local choice = json.choices[1]
-          local delta = (self.opts and self.opts.stream) and choice.delta or choice.message
-
-          if delta then
-            if delta.role then
-              output.role = delta.role
-            else
-              output.role = nil
-            end
-
-            -- Some providers may return empty content
-            if delta.content then
-              output.content = delta.content
-            else
-              output.content = ""
-            end
-
-            return {
-              status = "success",
-              output = output,
-            }
-          end
-        end
-      end
+      return openai.handlers.chat_output(self, data)
     end,
-
-    ---Output the data from the API ready for inlining into the current buffer
-    ---@param self CodeCompanion.Adapter
-    ---@param data string|table The streamed JSON data from the API, also formatted by the format_data handler
-    ---@param context table Useful context about the buffer to inline to
-    ---@return string|table|nil
     inline_output = function(self, data, context)
-      if data and data ~= "" then
-        data = prepare_data_for_json(data)
-        local ok, json = pcall(vim.json.decode, data, { luanil = { object = true } })
-
-        if ok then
-          --- Some third-party OpenAI forwarding services may have a return package with an empty json.choices.
-          if not json.choices or #json.choices == 0 then
-            return
-          end
-
-          local choice = json.choices[1]
-          local delta = (self.opts and self.opts.stream) and choice.delta or choice.message
-          if delta.content then
-            return delta.content
-          end
-        end
-      end
+      return openai.handlers.inline_output(self, data, context)
     end,
-
-    ---Function to run when the request has completed. Useful to catch errors
-    ---@param self CodeCompanion.Adapter
-    ---@param data table
-    ---@return nil
     on_exit = function(self, data)
-      if data.status >= 400 then
-        log:error("Error: %s", data.body)
-      end
+      return openai.handlers.on_exit(self, data)
     end,
   },
   schema = {
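What all three deleted bodies had in common is the OpenAI wire format: decode the chunk, take `choices[1]`, then read `delta` when streaming or `message` otherwise. Since DeepSeek speaks that same format, the DeepSeek-local copies were pure duplication. A condensed sketch of the shared extraction logic, distilled from the deleted code above (illustrative; the canonical version now lives only in the OpenAI adapter):

```lua
-- Illustrative core of the OpenAI-compatible extraction both adapters rely on.
---@param self table The adapter; self.opts.stream selects delta vs message
---@param json table A decoded chunk, e.g. { choices = { { delta = { content = "Hi" } } } }
---@return string|nil
local function extract_content(self, json)
  if not json.choices or #json.choices == 0 then
    return nil -- some OpenAI-compatible proxies send an empty choices array
  end
  local choice = json.choices[1]
  -- Streaming responses carry incremental "delta" objects; full responses carry "message".
  local delta = (self.opts and self.opts.stream) and choice.delta or choice.message
  return delta and delta.content or nil
end
```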
@@ -232,8 +149,7 @@ return {
       type = "number",
       optional = true,
       default = 0.95,
-      desc =
-      "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. Not used for R1.",
+      desc = "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. Not used for R1.",
       validate = function(n)
         return n >= 0 and n <= 1, "Must be between 0 and 1"
       end,
@@ -258,8 +174,7 @@ return {
       type = "integer",
       optional = true,
       default = 8192,
-      desc =
-      "The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length.",
+      desc = "The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length.",
       validate = function(n)
         return n > 0, "Must be greater than 0"
       end,
@@ -270,8 +185,7 @@ return {
       type = "number",
       optional = true,
       default = 0,
-      desc =
-      "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. Not used for R1",
+      desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. Not used for R1",
       validate = function(n)
         return n >= -2 and n <= 2, "Must be between -2 and 2"
       end,
@@ -282,8 +196,7 @@ return {
       type = "number",
       optional = true,
       default = 0,
-      desc =
-      "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. Not used for R1, but may be specified.",
+      desc = "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. Not used for R1, but may be specified.",
       validate = function(n)
         return n >= -2 and n <= 2, "Must be between -2 and 2"
       end,
@@ -294,8 +207,7 @@ return {
       type = "boolean",
       optional = true,
       default = nil,
-      desc =
-      "Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message. Not supported for R1.",
+      desc = "Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message. Not supported for R1.",
       subtype_key = {
         type = "integer",
       },
@@ -306,8 +218,7 @@ return {
       type = "string",
       optional = true,
       default = nil,
-      desc =
-      "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Learn more.",
+      desc = "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Learn more.",
       validate = function(u)
         return u:len() < 100, "Cannot be longer than 100 characters"
       end,
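The schema hunks are pure reformatting (the `desc =` continuation lines folded onto one line), but they do show the validation contract: each `validate` returns a boolean plus an error message. A hedged sketch of how a consumer might exercise such an entry (the `check` loop is illustrative; CodeCompanion's actual settings machinery is not part of this diff):

```lua
-- Illustrative consumer of a schema entry's validate contract.
local schema_entry = {
  type = "number",
  default = 0.95,
  validate = function(n)
    return n >= 0 and n <= 1, "Must be between 0 and 1"
  end,
}

-- Hypothetical helper: run an entry's validator and report failures.
local function check(entry, value)
  local ok, err = entry.validate(value)
  if not ok then
    print(("Invalid value %s: %s"):format(value, err))
  end
  return ok
end

check(schema_entry, 1.2) --> prints "Invalid value 1.2: Must be between 0 and 1"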
