Commit dd9225f

fix: make token counting optional in chat client
Currently, token counting and message limiting are enforced for all chat providers, but they should be optional, as not all chat providers support token limiting. This change makes token counting optional: it is enabled only when max_tokens and tokenizer are provided in the model config. The types in providers.lua have also been updated to reflect the optional nature of these fields.

Signed-off-by: Tomas Slusny <[email protected]>
1 parent: 1929831
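
The gist of the change: tokenizer loading and message limiting are now gated on the model's metadata. A minimal runnable sketch of that gating (the config tables and the should_limit helper are illustrative, not the plugin's API; field names follow the CopilotChat.Provider.model annotations further below):

-- Hypothetical model configs, for illustration only.
local limited = { id = 'gpt-4o', name = 'GPT-4o', tokenizer = 'o200k_base', max_prompt_tokens = 128000 }
local unlimited = { id = 'my-local-model', name = 'My local model' } -- no token metadata

-- Hypothetical helper mirroring the new guard in Client:ask.
local function should_limit(model_config)
  return model_config.max_prompt_tokens ~= nil and model_config.tokenizer ~= nil
end

print(should_limit(limited))   --> true: load the tokenizer, enforce the budget
print(should_limit(unlimited)) --> false: pass messages through as-is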

2 files changed: +60 -44 lines

lua/CopilotChat/client.lua

Lines changed: 55 additions & 39 deletions

@@ -251,6 +251,10 @@ function Client:fetch_models()
   if ok then
     for _, model in ipairs(provider_models) do
       model.provider = provider_name
+      if not model.version then
+        model.version = model.id
+      end
+
       if models[model.id] then
         model.id = model.id .. ':' .. provider_name
         model.version = model.version .. ':' .. provider_name

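The added fallback above gives models without an explicit version a usable value before the duplicate-id check runs. A tiny illustration (the model table is hypothetical):

local model = { id = 'my-local-model', provider = 'local' } -- no version field
if not model.version then
  model.version = model.id
end
print(model.version) --> my-local-model
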
@@ -353,7 +357,10 @@ function Client:ask(prompt, opts)
   local tokenizer = model_config.tokenizer or 'o200k_base'
   log.debug('Max tokens: ', max_tokens)
   log.debug('Tokenizer: ', tokenizer)
-  tiktoken.load(tokenizer)
+
+  if max_tokens and tokenizer then
+    tiktoken.load(tokenizer)
+  end
 
   notify.publish(notify.STATUS, 'Generating request')
 

@@ -369,45 +376,54 @@ function Client:ask(prompt, opts)
   local generated_messages = {}
   local selection_messages = generate_selection_messages(selection)
   local embeddings_messages = generate_embeddings_messages(embeddings)
-  local generated_tokens = 0
-  for _, message in ipairs(selection_messages) do
-    generated_tokens = generated_tokens + tiktoken.count(message.content)
-    table.insert(generated_messages, message)
-  end
-
-  -- Count required tokens that we cannot reduce
-  local prompt_tokens = tiktoken.count(prompt)
-  local system_tokens = tiktoken.count(system_prompt)
-  local required_tokens = prompt_tokens + system_tokens + generated_tokens
-
-  -- Reserve space for first embedding
-  local reserved_tokens = #embeddings_messages > 0
-      and tiktoken.count(embeddings_messages[1].content)
-    or 0
-
-  -- Calculate how many tokens we can use for history
-  local history_limit = max_tokens - required_tokens - reserved_tokens
-  local history_tokens = 0
-  for _, msg in ipairs(history) do
-    history_tokens = history_tokens + tiktoken.count(msg.content)
-  end
-
-  -- If we're over history limit, truncate history from the beginning
-  while history_tokens > history_limit and #history > 0 do
-    local removed = table.remove(history, 1)
-    history_tokens = history_tokens - tiktoken.count(removed.content)
-  end
-
-  -- Now add as many files as possible with remaining token budget (back to front)
-  local remaining_tokens = max_tokens - required_tokens - history_tokens
-  for i = #embeddings_messages, 1, -1 do
-    local message = embeddings_messages[i]
-    local tokens = tiktoken.count(message.content)
-    if remaining_tokens - tokens >= 0 then
-      remaining_tokens = remaining_tokens - tokens
+
+  if max_tokens then
+    -- Count tokens from selection messages
+    local generated_tokens = 0
+    for _, message in ipairs(selection_messages) do
+      generated_tokens = generated_tokens + tiktoken.count(message.content)
+      table.insert(generated_messages, message)
+    end
+
+    -- Count required tokens that we cannot reduce
+    local prompt_tokens = tiktoken.count(prompt)
+    local system_tokens = tiktoken.count(system_prompt)
+    local required_tokens = prompt_tokens + system_tokens + generated_tokens
+
+    -- Reserve space for first embedding
+    local reserved_tokens = #embeddings_messages > 0
+        and tiktoken.count(embeddings_messages[1].content)
+      or 0
+
+    -- Calculate how many tokens we can use for history
+    local history_limit = max_tokens - required_tokens - reserved_tokens
+    local history_tokens = 0
+    for _, msg in ipairs(history) do
+      history_tokens = history_tokens + tiktoken.count(msg.content)
+    end
+
+    -- If we're over history limit, truncate history from the beginning
+    while history_tokens > history_limit and #history > 0 do
+      local removed = table.remove(history, 1)
+      history_tokens = history_tokens - tiktoken.count(removed.content)
+    end
+
+    -- Now add as many files as possible with remaining token budget (back to front)
+    local remaining_tokens = max_tokens - required_tokens - history_tokens
+    for i = #embeddings_messages, 1, -1 do
+      local message = embeddings_messages[i]
+      local tokens = tiktoken.count(message.content)
+      if remaining_tokens - tokens >= 0 then
+        remaining_tokens = remaining_tokens - tokens
+        table.insert(generated_messages, message)
+      else
+        break
+      end
+    end
+  else
+    -- Add all embedding messages as we can't limit them
+    for _, message in ipairs(embeddings_messages) do
       table.insert(generated_messages, message)
-    else
-      break
     end
   end
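
The heart of the new max_tokens branch is the history truncation: drop the oldest messages until everything fits the budget. A self-contained sketch of that idea, using a naive word counter in place of tiktoken.count (the stub and sample data are illustrative):

-- Stub counter so the sketch runs standalone; the real code uses tiktoken.count.
local function count(text)
  local n = 0
  for _ in text:gmatch('%S+') do
    n = n + 1
  end
  return n
end

-- Drop oldest history entries until the remaining ones fit the budget.
local function truncate_history(history, budget)
  local used = 0
  for _, msg in ipairs(history) do
    used = used + count(msg.content)
  end
  while used > budget and #history > 0 do
    local removed = table.remove(history, 1) -- oldest message first
    used = used - count(removed.content)
  end
  return used
end

local history = {
  { content = 'an old message with quite a few words in it' }, -- 10 "tokens"
  { content = 'a newer message' },                             -- 3 "tokens"
  { content = 'the newest message' },                          -- 3 "tokens"
}
truncate_history(history, 7)
print(#history) --> 2: the oldest entry was dropped to fit the budget of 7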

lua/CopilotChat/config/providers.lua

Lines changed: 5 additions & 5 deletions

@@ -4,15 +4,15 @@ local utils = require('CopilotChat.utils')
 ---@class CopilotChat.Provider.model
 ---@field id string
 ---@field name string
----@field version string
----@field tokenizer string
----@field max_prompt_tokens number
----@field max_output_tokens number
+---@field version string?
+---@field tokenizer string?
+---@field max_prompt_tokens number?
+---@field max_output_tokens number?
 
 ---@class CopilotChat.Provider.agent
 ---@field id string
 ---@field name string
----@field description string
+---@field description string?
 
 ---@class CopilotChat.Provider
 ---@field disabled nil|boolean
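
With the loosened annotations, a provider can return models that omit any of the optional fields. A hypothetical get_models result that satisfies the updated CopilotChat.Provider.model class (ids, versions, and limits are made up):

---@type CopilotChat.Provider.model[]
local models = {
  -- Fully described model: the client loads its tokenizer and enforces limits.
  {
    id = 'gpt-4o',
    name = 'GPT-4o',
    version = 'gpt-4o-2024-11-20',
    tokenizer = 'o200k_base',
    max_prompt_tokens = 128000,
    max_output_tokens = 16384,
  },
  -- Bare model: fetch_models falls back to version = id, and Client:ask skips
  -- tokenizer loading and token limiting entirely.
  {
    id = 'my-local-model',
    name = 'My local model',
  },
}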
