Skip to content

Commit 64303c6

Browse files
committed
Add can_reason to model as suggested
1 parent f6b1d4b commit 64303c6

File tree

1 file changed

+42
-24
lines changed

1 file changed

+42
-24
lines changed

lua/codecompanion/adapters/anthropic.lua

Lines changed: 42 additions & 24 deletions
Original file line number · Diff line number · Diff line change
@@ -21,7 +21,7 @@ return {
2121
opts = {
2222
stream = true,
2323
cache_breakpoints = 4, -- Cache up to this many messages
24-
cache_over = 300, -- Cache any message which has this many tokens or more
24+
cache_over = 300, -- Cache any message which has this many tokens or more
2525
},
2626
url = "https://api.anthropic.com/v1/messages",
2727
env = {
@@ -81,18 +81,18 @@ return {
8181
form_messages = function(self, messages)
8282
-- Extract and format system messages
8383
local system = vim
84-
.iter(messages)
85-
:filter(function(msg)
86-
return msg.role == "system"
87-
end)
88-
:map(function(msg)
89-
return {
90-
type = "text",
91-
text = msg.content,
92-
cache_control = nil, -- To be set later if needed
93-
}
94-
end)
95-
:totable()
84+
.iter(messages)
85+
:filter(function(msg)
86+
return msg.role == "system"
87+
end)
88+
:map(function(msg)
89+
return {
90+
type = "text",
91+
text = msg.content,
92+
cache_control = nil, -- To be set later if needed
93+
}
94+
end)
95+
:totable()
9696
system = next(system) and system or nil
9797

9898
-- Remove system messages and merge user/assistant messages
@@ -114,9 +114,9 @@ return {
114114
for i = #messages, 1, -1 do
115115
local message = messages[i]
116116
if
117-
message.role == self.roles.user
118-
and tokens.calculate(message.content) >= self.opts.cache_over
119-
and breakpoints_used < self.opts.cache_breakpoints
117+
message.role == self.roles.user
118+
and tokens.calculate(message.content) >= self.opts.cache_over
119+
and breakpoints_used < self.opts.cache_breakpoints
120120
then
121121
message.content = {
122122
{
@@ -158,7 +158,7 @@ return {
158158
if ok then
159159
if json.type == "message_start" then
160160
input_tokens = (json.message.usage.input_tokens or 0)
161-
+ (json.message.usage.cache_creation_input_tokens or 0)
161+
+ (json.message.usage.cache_creation_input_tokens or 0)
162162

163163
output_tokens = json.message.usage.output_tokens or 0
164164
elseif json.type == "message_delta" then
@@ -260,10 +260,11 @@ return {
260260
order = 1,
261261
mapping = "parameters",
262262
type = "enum",
263-
desc = "The model that will complete your prompt. See https://docs.anthropic.com/claude/docs/models-overview for additional details and options.",
263+
desc =
264+
"The model that will complete your prompt. See https://docs.anthropic.com/claude/docs/models-overview for additional details and options.",
264265
default = "claude-3-7-sonnet-20250219",
265266
choices = {
266-
"claude-3-7-sonnet-20250219",
267+
["claude-3-7-sonnet-20250219"] = { opts = { can_reason = true } },
267268
"claude-3-5-sonnet-20241022",
268269
"claude-3-5-haiku-20241022",
269270
"claude-3-opus-20240229",
@@ -285,25 +286,39 @@ return {
285286
optional = true,
286287
default = false,
287288
desc = "Enable extended thinking for more thorough reasoning. Requires thinking_budget to be set.",
289+
condition = function(schema)
290+
local model = schema.model.default
291+
if schema.model.choices[model] and schema.model.choices[model].opts then
292+
return schema.model.choices[model].opts.can_reason
293+
end
294+
end,
288295
},
289296
thinking_budget = {
290297
order = 4,
291298
mapping = "parameters",
292299
type = "number",
293300
optional = true,
294301
default = 16000,
295-
desc = "The maximum number of tokens to use for thinking when extended_thinking is enabled. Must be less than max_tokens.",
302+
desc =
303+
"The maximum number of tokens to use for thinking when extended_thinking is enabled. Must be less than max_tokens.",
296304
validate = function(n)
297305
return n > 0, "Must be greater than 0"
298306
end,
307+
condition = function(schema)
308+
local model = schema.model.default
309+
if schema.model.choices[model] and schema.model.choices[model].opts then
310+
return schema.model.choices[model].opts.can_reason
311+
end
312+
end,
299313
},
300314
max_tokens = {
301315
order = 5,
302316
mapping = "parameters",
303317
type = "number",
304318
optional = true,
305319
default = 4096,
306-
desc = "The maximum number of tokens to generate before stopping. This parameter only specifies the absolute maximum number of tokens to generate. Different models have different maximum values for this parameter.",
320+
desc =
321+
"The maximum number of tokens to generate before stopping. This parameter only specifies the absolute maximum number of tokens to generate. Different models have different maximum values for this parameter.",
307322
validate = function(n)
308323
return n > 0 and n <= 32768, "Must be between 0 and 32768"
309324
end,
@@ -314,7 +329,8 @@ return {
314329
type = "number",
315330
optional = true,
316331
default = 0,
317-
desc = "Amount of randomness injected into the response. Ranges from 0.0 to 1.0. Use temperature closer to 0.0 for analytical / multiple choice, and closer to 1.0 for creative and generative tasks. Note that even with temperature of 0.0, the results will not be fully deterministic.",
332+
desc =
333+
"Amount of randomness injected into the response. Ranges from 0.0 to 1.0. Use temperature closer to 0.0 for analytical / multiple choice, and closer to 1.0 for creative and generative tasks. Note that even with temperature of 0.0, the results will not be fully deterministic.",
318334
validate = function(n)
319335
return n >= 0 and n <= 1, "Must be between 0 and 1.0"
320336
end,
@@ -325,7 +341,8 @@ return {
325341
type = "number",
326342
optional = true,
327343
default = nil,
328-
desc = "Computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p",
344+
desc =
345+
"Computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p",
329346
validate = function(n)
330347
return n >= 0 and n <= 1, "Must be between 0 and 1"
331348
end,
@@ -336,7 +353,8 @@ return {
336353
type = "number",
337354
optional = true,
338355
default = nil,
339-
desc = "Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses",
356+
desc =
357+
"Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses",
340358
validate = function(n)
341359
return n >= 0, "Must be greater than 0"
342360
end,

0 commit comments

Comments (0)