Commit 76844a0

Merge branch 'main' of https://github.com/continuedev/continue into snyk-upgrade-3095d9fd5e6b0170e79d83641d320c82
2 parents: 618e053 + d3756a1

File tree

15 files changed: +2768 / -1656 lines


core/config/yaml/loadYaml.ts

Lines changed: 23 additions & 20 deletions
@@ -1,9 +1,9 @@
 import {
   AssistantUnrolled,
+  AssistantUnrolledNonNullable,
   BLOCK_TYPES,
   ConfigResult,
   ConfigValidationError,
-  isAssistantUnrolledNonNullable,
   mergeConfigYamlRequestOptions,
   mergeUnrolledAssistants,
   ModelRole,
@@ -145,8 +145,8 @@ async function loadConfigYaml(options: {
     }
   }

-  if (config && isAssistantUnrolledNonNullable(config)) {
-    errors.push(...validateConfigYaml(config));
+  if (config) {
+    errors.push(...validateConfigYaml(nonNullifyConfigYaml(config)));
   }

   if (errors?.some((error) => error.fatal)) {
@@ -165,15 +165,30 @@ async function loadConfigYaml(options: {
   };
 }

+function nonNullifyConfigYaml(
+  unrolledAssistant: AssistantUnrolled,
+): AssistantUnrolledNonNullable {
+  return {
+    ...unrolledAssistant,
+    data: unrolledAssistant.data?.filter((k) => !!k),
+    context: unrolledAssistant.context?.filter((k) => !!k),
+    docs: unrolledAssistant.docs?.filter((k) => !!k),
+    mcpServers: unrolledAssistant.mcpServers?.filter((k) => !!k),
+    models: unrolledAssistant.models?.filter((k) => !!k),
+    prompts: unrolledAssistant.prompts?.filter((k) => !!k),
+    rules: unrolledAssistant.rules?.filter((k) => !!k).map((k) => k!),
+  };
+}
+
 export async function configYamlToContinueConfig(options: {
-  config: AssistantUnrolled;
+  unrolledAssistant: AssistantUnrolled;
   ide: IDE;
   ideInfo: IdeInfo;
   uniqueId: string;
   llmLogger: ILLMLogger;
   workOsAccessToken: string | undefined;
 }): Promise<{ config: ContinueConfig; errors: ConfigValidationError[] }> {
-  let { config, ide, ideInfo, uniqueId, llmLogger } = options;
+  let { unrolledAssistant, ide, ideInfo, uniqueId, llmLogger } = options;

   const localErrors: ConfigValidationError[] = [];

@@ -203,22 +218,10 @@ export async function configYamlToContinueConfig(options: {
       subagent: null,
     },
     rules: [],
-    requestOptions: { ...config.requestOptions },
+    requestOptions: { ...unrolledAssistant.requestOptions },
   };

-  // Right now, if there are any missing packages in the config, then we will just throw an error
-  if (!isAssistantUnrolledNonNullable(config)) {
-    return {
-      config: continueConfig,
-      errors: [
-        {
-          message:
-            "Failed to load config due to missing blocks, see which blocks are missing below",
-          fatal: true,
-        },
-      ],
-    };
-  }
+  const config = nonNullifyConfigYaml(unrolledAssistant);

   for (const rule of config.rules ?? []) {
     const convertedRule = convertYamlRuleToContinueRule(rule);
@@ -447,7 +450,7 @@ export async function loadContinueConfigFromYaml(options: {

   const { config: continueConfig, errors: localErrors } =
     await configYamlToContinueConfig({
-      config: configYamlResult.config,
+      unrolledAssistant: configYamlResult.config,
       ide,
       ideInfo,
       uniqueId,
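
As a rough illustration of the behavior change (not part of the commit; the block shapes below are hypothetical and abbreviated), nonNullifyConfigYaml simply filters null entries out of each block array. Null entries are what the old code treated as missing blocks and turned into a fatal error; after this change the remaining blocks are validated and loaded as usual:

// Illustrative sketch only; assumes AssistantUnrolled and nonNullifyConfigYaml
// are in scope as inside loadYaml.ts, and abbreviates the block fields.
const unrolled = {
  name: "my-assistant",
  models: [{ name: "GPT-4", provider: "openai", model: "gpt-4" }, null],
  rules: [null, "Always answer in English"],
} as unknown as AssistantUnrolled;

const cleaned = nonNullifyConfigYaml(unrolled);
// cleaned.models -> [{ name: "GPT-4", provider: "openai", model: "gpt-4" }]
// cleaned.rules  -> ["Always answer in English"]
// Instead of returning a fatal "missing blocks" error, the config now loads
// with the unresolved (null) blocks dropped.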

core/llm/countTokens.ts

Lines changed: 5 additions & 2 deletions
@@ -18,6 +18,7 @@ import {
 import { renderChatMessage } from "../util/messageContent.js";
 import { AsyncEncoder, LlamaAsyncEncoder } from "./asyncEncoder.js";
 import { DEFAULT_PRUNING_LENGTH } from "./constants.js";
+import { getAdjustedTokenCountFromModel } from "./getAdjustedTokenCount.js";
 import llamaTokenizer from "./llamaTokenizer.js";
 interface Encoding {
   encode: Tiktoken["encode"];
@@ -114,8 +115,9 @@
   modelName = "llama2",
 ): number {
   const encoding = encodingForModel(modelName);
+  let baseTokens = 0;
   if (Array.isArray(content)) {
-    return content.reduce((acc, part) => {
+    baseTokens = content.reduce((acc, part) => {
       return (
         acc +
         (part.type === "text"
@@ -124,8 +126,9 @@
       );
     }, 0);
   } else {
-    return encoding.encode(content ?? "", "all", []).length;
+    baseTokens = encoding.encode(content ?? "", "all", []).length;
   }
+  return getAdjustedTokenCountFromModel(baseTokens, modelName);
 }

 // https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/10
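
As a rough sketch of the new flow (not part of the commit; the token numbers are hypothetical), countTokens now treats the encoder output as a base count and scales it through getAdjustedTokenCountFromModel, which applies a per-model-family safety buffer:

import { getAdjustedTokenCountFromModel } from "./getAdjustedTokenCount.js";

// Suppose the llama/gpt-style encoder reported 200 tokens for some content:
const baseTokens = 200;

getAdjustedTokenCountFromModel(baseTokens, "claude-3-opus"); // Math.ceil(200 * 1.23) = 246
getAdjustedTokenCountFromModel(baseTokens, "gemini-pro");    // Math.ceil(200 * 1.18) = 236
getAdjustedTokenCountFromModel(baseTokens, "gpt-4");         // 200 (multiplier stays 1)
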
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+import { getAdjustedTokenCountFromModel } from "./getAdjustedTokenCount";
+
+describe("getAdjustedTokenCountFromModel", () => {
+  it("should return base tokens for non-special models", () => {
+    expect(getAdjustedTokenCountFromModel(100, "gpt-4")).toBe(100);
+    expect(getAdjustedTokenCountFromModel(100, "llama2")).toBe(100);
+    expect(getAdjustedTokenCountFromModel(100, "random-model")).toBe(100);
+  });
+
+  it("should apply multiplier for Claude models", () => {
+    expect(getAdjustedTokenCountFromModel(100, "claude-3-opus")).toBe(123);
+    expect(getAdjustedTokenCountFromModel(100, "claude-3.5-sonnet")).toBe(123);
+    expect(getAdjustedTokenCountFromModel(100, "CLAUDE-2")).toBe(123);
+    expect(getAdjustedTokenCountFromModel(50, "claude")).toBe(62); // 50 * 1.23 = 61.5, ceiled to 62
+  });
+
+  it("should apply multiplier for Gemini models", () => {
+    expect(getAdjustedTokenCountFromModel(100, "gemini-pro")).toBe(118);
+    expect(getAdjustedTokenCountFromModel(100, "gemini-1.5-pro")).toBe(118);
+    expect(getAdjustedTokenCountFromModel(100, "GEMINI-flash")).toBe(118);
+    expect(getAdjustedTokenCountFromModel(50, "gemini")).toBe(59); // 50 * 1.18 = 59
+  });
+
+  it("should apply multiplier for Mistral family models", () => {
+    expect(getAdjustedTokenCountFromModel(100, "mistral-large")).toBe(126);
+    expect(getAdjustedTokenCountFromModel(100, "mixtral-8x7b")).toBe(126);
+    expect(getAdjustedTokenCountFromModel(100, "devstral")).toBe(126);
+    expect(getAdjustedTokenCountFromModel(100, "CODESTRAL")).toBe(126);
+    expect(getAdjustedTokenCountFromModel(50, "mistral")).toBe(63); // 50 * 1.26 = 63
+  });
+
+  it("should handle edge cases", () => {
+    expect(getAdjustedTokenCountFromModel(0, "claude")).toBe(0);
+    expect(getAdjustedTokenCountFromModel(1, "gemini")).toBe(2); // 1 * 1.18 = 1.18, ceiled to 2
+    expect(getAdjustedTokenCountFromModel(1000, "mixtral")).toBe(1260);
+  });
+
+  it("should handle empty or undefined model names", () => {
+    expect(getAdjustedTokenCountFromModel(100, "")).toBe(100);
+    expect(getAdjustedTokenCountFromModel(100, undefined as any)).toBe(100);
+  });
+
+  it("should be case-insensitive", () => {
+    expect(getAdjustedTokenCountFromModel(100, "ClAuDe-3-OpUs")).toBe(123);
+    expect(getAdjustedTokenCountFromModel(100, "GeMiNi-PrO")).toBe(118);
+    expect(getAdjustedTokenCountFromModel(100, "MiXtRaL")).toBe(126);
+  });
+});

core/llm/getAdjustedTokenCount.ts

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+// Importing a bunch of tokenizers can be very resource intensive (MB-scale per tokenizer)
+// Using token counting APIs (e.g. for anthropic) can be complicated and unreliable in many environments
+// So for now we will just use super fast gpt-tokenizer and apply safety buffers
+// I'm using rough estimates from this article to apply safety buffers to common tokenizers
+// which will have HIGHER token counts than gpt. Roughly using token ratio from article + 10%
+// https://medium.com/@disparate-ai/not-all-tokens-are-created-equal-7347d549af4d
+const ANTHROPIC_TOKEN_MULTIPLIER = 1.23;
+const GEMINI_TOKEN_MULTIPLIER = 1.18;
+const MISTRAL_TOKEN_MULTIPLIER = 1.26;
+
+/**
+ * Adjusts token count based on model-specific tokenizer differences.
+ * Since we use llama tokenizer (~= gpt tokenizer) for all models, we apply
+ * multipliers for models known to have higher token counts.
+ *
+ * @param baseTokens - Token count from llama/gpt tokenizer
+ * @param modelName - Name of the model
+ * @returns Adjusted token count with safety buffer
+ */
+export function getAdjustedTokenCountFromModel(
+  baseTokens: number,
+  modelName: string,
+): number {
+  let multiplier = 1;
+  const lowerModelName = modelName?.toLowerCase() ?? "";
+  if (lowerModelName.includes("claude")) {
+    multiplier = ANTHROPIC_TOKEN_MULTIPLIER;
+  } else if (lowerModelName.includes("gemini")) {
+    multiplier = GEMINI_TOKEN_MULTIPLIER;
+  } else if (
+    lowerModelName.includes("stral") ||
+    lowerModelName.includes("mixtral")
+  ) {
+    // Mistral family models: mistral, mixtral, codestral, devstral, etc
+    multiplier = MISTRAL_TOKEN_MULTIPLIER;
+  }
+  return Math.ceil(baseTokens * multiplier);
+}
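
A small usage sketch (not part of the commit; the model IDs are hypothetical) showing how the branches above select a multiplier: "stral" matches mistral, codestral, and devstral, the separate "mixtral" check catches Mixtral, and anything else keeps a multiplier of 1:

import { getAdjustedTokenCountFromModel } from "./getAdjustedTokenCount";

// Hypothetical model IDs, used only to show which branch each one hits.
const examples = ["mistral-large", "codestral-latest", "mixtral-8x7b", "gpt-4o"];
for (const id of examples) {
  console.log(id, getAdjustedTokenCountFromModel(1000, id));
}
// mistral-large    1260  (Math.ceil(1000 * 1.26), via the "stral" substring)
// codestral-latest 1260  (via the "stral" substring)
// mixtral-8x7b     1260  (via the explicit "mixtral" check)
// gpt-4o           1000  (no branch matched, so the count is unchanged)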
