Skip to content

Commit 9016e66

Browse files
author
Calvinn Ng
committed
update core interfaces to include new attributes
1 parent edc0e70 commit 9016e66

File tree

5 files changed

+259
-104
lines changed

5 files changed

+259
-104
lines changed

core/config/handler.ts

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -92,42 +92,6 @@ export class ConfigHandler {
9292
return this.savedConfig;
9393
}
9494

95-
setupLlm(llm: ILLM): ILLM {
96-
llm._fetch = async (input, init) => {
97-
const resp = await fetchwithRequestOptions(
98-
new URL(input),
99-
{ ...init },
100-
llm.requestOptions
101-
);
102-
103-
if (!resp.ok) {
104-
let text = await resp.text();
105-
if (resp.status === 404 && !resp.url.includes("/v1")) {
106-
if (text.includes("try pulling it first")) {
107-
const model = JSON.parse(text).error.split(" ")[1].slice(1, -1);
108-
text = `The model "${model}" was not found. To download it, run \`ollama run ${model}\`.`;
109-
} else if (text.includes("/api/chat")) {
110-
text =
111-
"The /api/chat endpoint was not found. This may mean that you are using an older version of Ollama that does not support /api/chat. Upgrading to the latest version will solve the issue.";
112-
} else {
113-
text =
114-
"This may mean that you forgot to add '/v1' to the end of your 'apiBase' in config.json.";
115-
}
116-
}
117-
throw new Error(
118-
`HTTP ${resp.status} ${resp.statusText} from ${resp.url}\n\n${text}`
119-
);
120-
}
121-
122-
return resp;
123-
};
124-
125-
llm.writeLog = async (log: string) => {
126-
this.writeLog(log);
127-
};
128-
return llm;
129-
}
130-
13195
async llmFromTitle(title?: string): Promise<ILLM> {
13296
const config = await this.loadConfig();
13397
const model =
@@ -136,7 +100,7 @@ export class ConfigHandler {
136100
throw new Error("No model found");
137101
}
138102

139-
return this.setupLlm(model);
103+
return model;
140104
}
141105

142106
async loadCommandLlm(title?: string): Promise<ILLM> {
@@ -147,6 +111,6 @@ export class ConfigHandler {
147111
throw new Error("No commandModel found");
148112
}
149113

150-
return this.setupLlm(model);
114+
return model;
151115
}
152116
}

core/config/load.ts

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ import {
2525
import { contextProviderClassFromName } from "../context/providers";
2626
import CustomContextProviderClass from "../context/providers/CustomContextProvider";
2727
import FileContextProvider from "../context/providers/FileContextProvider";
28-
import { AllRerankers } from "../context/rerankers";
29-
import { LLMReranker } from "../context/rerankers/llm";
28+
// import { AllRerankers } from "../context/rerankers";
29+
// import { LLMReranker } from "../context/rerankers/llm";
3030
import { AllEmbeddingsProviders } from "../indexing/embeddings";
3131
import TransformersJsEmbeddingsProvider from "../indexing/embeddings/TransformersJsEmbeddingsProvider";
3232
import { BaseLLM } from "../llm";
@@ -36,9 +36,7 @@ import { copyOf } from "../util";
3636
import mergeJson from "../util/merge";
3737
import {
3838
getConfigJsPath,
39-
getConfigJsPathForRemote,
4039
getConfigJsonPath,
41-
getConfigJsonPathForRemote,
4240
getConfigTsPath,
4341
getContinueDotEnv,
4442
migrate,
@@ -278,36 +276,36 @@ async function intermediateToFinalConfig(
278276
}
279277
}
280278

281-
// Embeddings Provider
282-
if (
283-
(config.embeddingsProvider as EmbeddingsProviderDescription | undefined)
284-
?.provider
285-
) {
286-
const { provider, ...options } =
287-
config.embeddingsProvider as EmbeddingsProviderDescription;
288-
config.embeddingsProvider = new AllEmbeddingsProviders[provider](options);
289-
}
290-
291-
if (!config.embeddingsProvider) {
292-
config.embeddingsProvider = new TransformersJsEmbeddingsProvider();
293-
}
279+
// // Embeddings Provider
280+
// if (
281+
// (config.embeddingsProvider as EmbeddingsProviderDescription | undefined)
282+
// ?.provider
283+
// ) {
284+
// const { provider, ...options } =
285+
// config.embeddingsProvider as EmbeddingsProviderDescription;
286+
// config.embeddingsProvider = new AllEmbeddingsProviders[provider](options);
287+
// }
294288

295-
// Reranker
296-
if (config.reranker && !(config.reranker as Reranker | undefined)?.rerank) {
297-
const { name, params } = config.reranker as RerankerDescription;
298-
const rerankerClass = AllRerankers[name];
289+
// if (!config.embeddingsProvider) {
290+
// config.embeddingsProvider = new TransformersJsEmbeddingsProvider();
291+
// }
299292

300-
if (name === "llm") {
301-
const llm = models.find((model) => model.title === params?.modelTitle);
302-
if (!llm) {
303-
console.warn(`Unknown model ${params?.modelTitle}`);
304-
} else {
305-
config.reranker = new LLMReranker(llm);
306-
}
307-
} else if (rerankerClass) {
308-
config.reranker = new rerankerClass(params);
309-
}
310-
}
293+
// // Reranker
294+
// if (config.reranker && !(config.reranker as Reranker | undefined)?.rerank) {
295+
// const { name, params } = config.reranker as RerankerDescription;
296+
// const rerankerClass = AllRerankers[name];
297+
298+
// if (name === "llm") {
299+
// const llm = models.find((model) => model.title === params?.modelTitle);
300+
// if (!llm) {
301+
// console.warn(`Unknown model ${params?.modelTitle}`);
302+
// } else {
303+
// config.reranker = new LLMReranker(llm);
304+
// }
305+
// } else if (rerankerClass) {
306+
// config.reranker = new rerankerClass(params);
307+
// }
308+
// }
311309

312310
return {
313311
...config,
@@ -336,7 +334,8 @@ function finalToBrowserConfig(
336334
completionOptions: m.completionOptions,
337335
systemMessage: m.systemMessage,
338336
requestOptions: m.requestOptions,
339-
promptTemplates: m.promptTemplates,
337+
// TODO: Types incompanitable. Correct them.
338+
// promptTemplates: m.promptTemplates,
340339
})),
341340
systemMessage: final.systemMessage,
342341
completionOptions: final.completionOptions,

core/context/rerankers/cohere.ts

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import fetch from "node-fetch";
2+
import { Chunk, Reranker } from "../../index.js";
3+
4+
export class CohereReranker implements Reranker {
5+
name = "cohere";
6+
7+
static defaultOptions = {
8+
apiBase: "https://api.cohere.ai/v1/",
9+
model: "rerank-english-v3.0",
10+
};
11+
12+
constructor(
13+
private readonly params: {
14+
apiBase?: string;
15+
apiKey: string;
16+
model?: string;
17+
},
18+
) {}
19+
20+
async rerank(query: string, chunks: Chunk[]): Promise<number[]> {
21+
let apiBase = this.params.apiBase ?? CohereReranker.defaultOptions.apiBase;
22+
if (!apiBase.endsWith("/")) {
23+
apiBase += "/";
24+
}
25+
26+
const resp = await fetch(new URL("rerank", apiBase), {
27+
method: "POST",
28+
headers: {
29+
Authorization: `Bearer ${this.params.apiKey}`,
30+
"Content-Type": "application/json",
31+
},
32+
body: JSON.stringify({
33+
model: this.params.model ?? CohereReranker.defaultOptions.model,
34+
query,
35+
documents: chunks.map((chunk) => chunk.content),
36+
}),
37+
});
38+
39+
if (!resp.ok) {
40+
throw new Error(await resp.text());
41+
}
42+
43+
const data = (await resp.json()) as any;
44+
const results = data.results.sort((a: any, b: any) => a.index - b.index);
45+
return results.map((result: any) => result.relevance_score);
46+
}
47+
}

0 commit comments

Comments
 (0)