Skip to content

Commit 75b9a47

Browse files
Add memoized model info and checks against size.
1 parent b7308d8 commit 75b9a47

File tree

1 file changed

+48
-6
lines changed

1 file changed

+48
-6
lines changed

src/api/providers/native-ollama.ts

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@ import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from ".
1111
// kilocode_change start
1212
import { fetchWithTimeout } from "./kilocode/fetchWithTimeout"
1313
const OLLAMA_TIMEOUT_MS = 3_600_000
14+
15+
const TOKEN_ESTIMATION_FACTOR = 4 //Industry standard technique for estimating token counts without actually implementing a parser/tokenizer
16+
17+
function estimateOllamaTokenCount(messages: Message[]): number {
18+
const totalChars = messages.reduce((acc, msg) => acc + (msg.content?.length || 0), 0)
19+
return Math.ceil(totalChars / TOKEN_ESTIMATION_FACTOR)
20+
}
1421
// kilocode_change end
1522

1623
function convertToOllamaMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): Message[] {
@@ -136,18 +143,28 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio
136143
protected options: ApiHandlerOptions
137144
private client: Ollama | undefined
138145
protected models: Record<string, ModelInfo> = {}
146+
private isInitialized = false
139147

140148
constructor(options: ApiHandlerOptions) {
	super()
	this.options = options
	// Kick off model discovery eagerly, but do NOT let a failed fetch become
	// an unhandled promise rejection: callers re-run initialize() lazily
	// before use (see createMessage/completePrompt), so a transient network
	// error at construction time is safely retried there.
	void this.initialize().catch(() => {})
}
153+
154+
private async initialize(): Promise<void> {
155+
if (this.isInitialized) {
156+
return
157+
}
158+
await this.fetchModels()
159+
this.isInitialized = true
143160
}
144161

145162
private ensureClient(): Ollama {
146163
if (!this.client) {
147164
try {
148165
// kilocode_change start
149166
const headers = this.options.ollamaApiKey
150-
? { Authorization: this.options.ollamaApiKey } //Yes, this is weird, its not a Bearer token
167+
? { Authorization: this.options.ollamaApiKey } // Yes, this is weird, it's not a Bearer token
151168
: undefined
152169
// kilocode_change end
153170

@@ -170,15 +187,26 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio
170187
messages: Anthropic.Messages.MessageParam[],
171188
metadata?: ApiHandlerCreateMessageMetadata,
172189
): ApiStream {
190+
if (!this.isInitialized) {
191+
await this.initialize()
192+
}
193+
173194
const client = this.ensureClient()
174-
const { id: modelId, info: modelInfo } = await this.fetchModel()
195+
const { id: modelId, info: modelInfo } = this.getModel()
175196
const useR1Format = modelId.toLowerCase().includes("deepseek-r1")
176197

177198
const ollamaMessages: Message[] = [
178199
{ role: "system", content: systemPrompt },
179200
...convertToOllamaMessages(messages),
180201
]
181202

203+
const estimatedTokenCount = estimateOllamaTokenCount(ollamaMessages)
204+
if (modelInfo.maxTokens && estimatedTokenCount > modelInfo.maxTokens) {
205+
throw new Error(
206+
`Input message is too long for the selected model. Estimated tokens: ${estimatedTokenCount}, Max tokens: ${modelInfo.maxTokens}`,
207+
)
208+
}
209+
182210
const matcher = new XmlMatcher(
183211
"think",
184212
(chunk) =>
@@ -259,23 +287,37 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio
259287
}
260288
}
261289

262-
async fetchModel() {
290+
async fetchModels() {
	// Refresh the cached model map from the Ollama server and hand it back
	// so callers can use the result without touching `this.models` directly.
	const fetched = await getOllamaModels(this.options.ollamaBaseUrl)
	this.models = fetched
	return fetched
}
266294

267295
override getModel(): { id: string; info: ModelInfo } {
	// Resolve the configured model id against the cached model map.
	// Unknown ids throw (rather than silently falling back to defaults) so a
	// misconfiguration surfaces to the caller instead of using wrong limits.
	const modelId = this.options.ollamaModelId || ""
	const modelInfo = this.models[modelId]
	if (modelInfo) {
		return { id: modelId, info: modelInfo }
	}

	const availableModels = Object.keys(this.models)
	if (availableModels.length === 0) {
		throw new Error(`Model ${modelId} not found. No models available.`)
	}
	throw new Error(`Model ${modelId} not found. Available models: ${availableModels.join(", ")}`)
}
274313

275314
async completePrompt(prompt: string): Promise<string> {
276315
try {
316+
if (!this.isInitialized) {
317+
await this.initialize()
318+
}
277319
const client = this.ensureClient()
278-
const { id: modelId } = await this.fetchModel()
320+
const { id: modelId } = this.getModel()
279321
const useR1Format = modelId.toLowerCase().includes("deepseek-r1")
280322

281323
const response = await client.chat({

0 commit comments

Comments
 (0)