
Commit eaf4b9d

Merge pull request #152 from olasunkanmi-SE/ai-agents
feat(llms): Introduce Message class for structured chat history
2 parents 7cd0a42 + 689610d commit eaf4b9d

12 files changed: +198, -192 lines

src/agents/orchestrator.ts

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@ export class Orchestrator extends BaseAiAgent implements vscode.Disposable {
     super();
     this.disposables.push(
       this.onStatusChange(this.handleStatus.bind(this)),
-      this.onPrompt(this.handleQuery.bind(this)),
+      this.onPromptGenerated(this.handlePromptGeneratedEvent.bind(this)),
       this.onError(this.handleError.bind(this)),
     );
   }
@@ -23,12 +23,12 @@ export class Orchestrator extends BaseAiAgent implements vscode.Disposable {
   }

   public handleStatus(event: IEventPayload) {
-    this.publish("onQuery", JSON.stringify(event));
     console.log(` ${event.message} - ${JSON.stringify(event)}`);
   }

-  public handleQuery(event: IEventPayload) {
+  public handlePromptGeneratedEvent(event: IEventPayload) {
     console.error(`Error: ${event.message})`);
+    this.publish("onResponse", JSON.stringify(event));
   }

   public handleError(event: IEventPayload) {

src/emitter/agent-emitter.ts

Lines changed: 4 additions & 2 deletions
@@ -2,11 +2,13 @@ import * as vscode from "vscode";
 import { BaseEmitter } from "./emitter";
 import { IAgentEventMap, IEventPayload } from "./interface";

-export class EventEmitter extends BaseEmitter<IAgentEventMap> {
+export class EventEmitter extends BaseEmitter<Record<string, IEventPayload>> {
   onStatusChange: vscode.Event<IEventPayload> = this.createEvent("onStatus");
   onError: vscode.Event<IEventPayload> = this.createEvent("onError");
   onUpdate: vscode.Event<IEventPayload> = this.createEvent("onUpdate");
-  onPrompt: vscode.Event<IEventPayload> = this.createEvent("onQuery");
+  onPromptGenerated: vscode.Event<IEventPayload> = this.createEvent("onQuery");
+  onThinking: vscode.Event<IEventPayload> = this.createEvent("onThinking");
+  onResponse: vscode.Event<IEventPayload> = this.createEvent("onResponse");

   /**
    * Emits a generic event with specified status, message, and optional data.
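
With the renamed onPromptGenerated event and the new onThinking and onResponse events, the publish/subscribe flow can be wired roughly as in the sketch below. The event and method names come from this diff; the handler bodies, payload values, and import paths (assumed to be under src/) are illustrative only.

import { Orchestrator } from "../agents/orchestrator";
import { IEventPayload } from "../emitter/interface";

// Sketch only: event and method names come from this diff, the handlers do not.
const orchestrator = Orchestrator.getInstance();

// A consumer (e.g. a webview provider) can surface intermediate updates...
const thinkingSub = orchestrator.onThinking((event: IEventPayload) => {
  console.log(`thinking: ${event.message}`);
});

// ...and the responses that the Orchestrator republishes on "onResponse".
const responseSub = orchestrator.onResponse((event: IEventPayload) => {
  console.log(`response: ${event.message}`);
});

// An agent publishes "onQuery"; the Orchestrator receives it via
// onPromptGenerated and republishes it as "onResponse".
orchestrator.publish("onQuery", JSON.stringify({ thought: "...", queries: [] }));

// vscode.Event subscriptions return Disposables; clean them up on teardown.
thinkingSub.dispose();
responseSub.dispose();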

src/emitter/emitter.ts

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 import * as vscode from "vscode";
 import { Logger } from "../infrastructure/logger/logger";
-export class BaseEmitter<EventMap extends Record<string, any>> {
+import { IEventPayload } from "./interface";
+export class BaseEmitter<EventMap extends Record<string, IEventPayload>> {
   protected logger: Logger;
   constructor() {
     this.logger = new Logger("BaseEmitter");

src/emitter/interface.ts

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@ export interface IAgentEventMap {
   onError: IEventPayload;
   onUpdate: IEventPayload;
   onQuery: IEventPayload;
+  onResponse: IEventPayload;
+  onThinking: IEventPayload;
 }

 export interface IEventPayload {

src/llms/gemini/gemini.ts

Lines changed: 48 additions & 2 deletions
@@ -4,18 +4,28 @@ import {
   GenerativeModel,
   GoogleGenerativeAI,
 } from "@google/generative-ai";
+import * as vscode from "vscode";
+import { Orchestrator } from "../../agents/orchestrator";
+import { ProcessInputResult } from "../../application/interfaces/agent.interface";
+import { createPrompt } from "../../utils/prompt";
 import { BaseLLM } from "../base";
 import { GeminiModelResponseType, ILlmConfig } from "../interface";

-export class GeminiLLM extends BaseLLM<GeminiModelResponseType> {
+export class GeminiLLM
+  extends BaseLLM<GeminiModelResponseType>
+  implements vscode.Disposable
+{
   private readonly generativeAi: GoogleGenerativeAI;
   private response: EmbedContentResponse | GenerateContentResult | undefined;
+  protected readonly orchestrator: Orchestrator;
+  private readonly disposables: vscode.Disposable[] = [];

   constructor(config: ILlmConfig) {
     super(config);
     this.config = config;
     this.generativeAi = new GoogleGenerativeAI(this.config.apiKey);
     this.response = undefined;
+    this.orchestrator = Orchestrator.getInstance();
   }

   public async generateEmbeddings(text: string): Promise<number[]> {
@@ -42,7 +52,7 @@ export class GeminiLLM extends BaseLLM<GeminiModelResponseType> {
     }
   }

-  private getModel(): GenerativeModel {
+  getModel(): GenerativeModel {
     try {
       const model: GenerativeModel | undefined =
         this.generativeAi.getGenerativeModel({
@@ -59,11 +69,47 @@ export class GeminiLLM extends BaseLLM<GeminiModelResponseType> {
     }
   }

+  async generateContent(
+    userInput: string,
+  ): Promise<Partial<ProcessInputResult>> {
+    try {
+      const prompt = createPrompt(userInput);
+      const model = this.getModel();
+      const generateContentResponse: GenerateContentResult =
+        await model.generateContent(prompt);
+      const { text, usageMetadata } = generateContentResponse.response;
+      const parsedResponse = this.orchestrator.parseResponse(text());
+      const extractedQueries = parsedResponse.queries;
+      const extractedThought = parsedResponse.thought;
+      const tokenCount = usageMetadata?.totalTokenCount ?? 0;
+      const result = {
+        queries: extractedQueries,
+        tokens: tokenCount,
+        prompt: userInput,
+        thought: extractedThought,
+      };
+      this.orchestrator.publish("onQuery", JSON.stringify(result));
+      return result;
+    } catch (error: any) {
+      this.orchestrator.publish("onError", error);
+      vscode.window.showErrorMessage("Error processing user query");
+      this.logger.error(
+        "Error generating, queries, thoughts from user query",
+        error,
+      );
+      throw error;
+    }
+  }
+
   public createSnapShot(data?: any): GeminiModelResponseType {
     return { ...this.response, ...data };
   }

   public loadSnapShot(snapshot: ReturnType<typeof this.createSnapShot>): void {
     Object.assign(this, snapshot);
   }
+
+  public dispose(): void {
+    this.disposables.forEach((d) => d.dispose());
+  }
 }
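
A minimal usage sketch for the new generateContent method follows. The model name and config values are placeholders, ILlmConfig may require fields not visible in this diff, and the orchestrator singleton is assumed to be available; treat this as illustrative rather than part of the commit.

import { GeminiLLM } from "./src/llms/gemini/gemini";

// Sketch only: config values are placeholders, and ILlmConfig may require
// additional fields that this diff does not show.
async function run(): Promise<void> {
  const gemini = new GeminiLLM({
    apiKey: process.env.GEMINI_API_KEY ?? "",
    model: "gemini-1.5-flash",
  });

  // generateContent builds a prompt, parses the model output into
  // thought/queries, counts tokens, and publishes the result on "onQuery".
  const result = await gemini.generateContent("List the TODOs in this workspace");
  console.log(result.thought, result.queries, result.tokens);

  // GeminiLLM now implements vscode.Disposable.
  gemini.dispose();
}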

src/llms/interface.ts

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ export interface ILlmConfig {
   model: string;
   tools?: any[];
   apiKey: string;
-  baseUrl: string;
+  baseUrl?: string;
   systemInstruction?: string;
   cachedContent?: any;
   additionalConfig?: Record<string, any>;

src/llms/message.ts

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import { Part } from "@google/generative-ai";
+
+type Role = "function" | "user" | "model" | "assistant" | "system";
+
+export interface IBaseMessage {
+  [key: string]: any;
+  createdAt?: string;
+}
+
+export interface IMessageInput {
+  role: Role;
+  parts?: Part[];
+  content?: string;
+}
+
+export class Message {
+  constructor(
+    readonly role: Role,
+    readonly content: string = "",
+    readonly parts: Part[] = [],
+  ) {}
+
+  static of({ role, content, parts }: IMessageInput) {
+    return new Message(role, content, parts);
+  }
+
+  createSnapShot(): IMessageInput {
+    return {
+      role: this.role,
+      content: this.content,
+      parts: this.parts,
+    };
+  }
+
+  loadSnapShot(state: IMessageInput) {
+    return Object.assign(this, state);
+  }
+}
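
The new Message class normalises chat-history entries across providers. A short usage sketch, with illustrative content strings:

import { Message, IMessageInput } from "./src/llms/message";

// Build history entries via the static factory instead of raw object literals.
const userTurn = Message.of({ role: "user", content: "Explain this function" });
const modelTurn = Message.of({ role: "assistant", content: "It parses the config..." });

const history: IMessageInput[] = [userTurn, modelTurn];

// createSnapShot returns a plain IMessageInput that can be persisted,
// and loadSnapShot rehydrates a Message from that shape.
const saved = userTurn.createSnapShot();
const restored = new Message("user").loadSnapShot(saved);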

src/providers/anthropic.ts

Lines changed: 18 additions & 28 deletions
@@ -12,15 +12,10 @@ import {
   getXGroKBaseURL,
 } from "../utils/utils";
 import { Memory } from "../memory/base";
-
-type Role = "user" | "assistant";
-export interface IHistory {
-  role: Role;
-  content: string;
-}
+import { IMessageInput, Message } from "../llms/message";

 export class AnthropicWebViewProvider extends BaseWebViewProvider {
-  chatHistory: IHistory[] = [];
+  chatHistory: IMessageInput[] = [];
   readonly model: Anthropic;
   constructor(
     extensionUri: vscode.Uri,
@@ -40,15 +35,19 @@ export class AnthropicWebViewProvider extends BaseWebViewProvider {
     try {
       const type = currentChat === "bot" ? "bot-response" : "user-input";
       if (currentChat === "bot") {
-        this.chatHistory.push({
-          role: "assistant",
-          content: response,
-        });
+        this.chatHistory.push(
+          Message.of({
+            role: "assistant",
+            content: response,
+          }),
+        );
       } else {
-        this.chatHistory.push({
-          role: "user",
-          content: response,
-        });
+        this.chatHistory.push(
+          Message.of({
+            role: "user",
+            content: response,
+          }),
+        );
       }

       if (this.chatHistory.length === 2) {
@@ -69,27 +68,18 @@ export class AnthropicWebViewProvider extends BaseWebViewProvider {
     }
   }

-  async generateResponse(
-    message: string,
-    apiKey?: string,
-    name?: string,
-  ): Promise<string | undefined> {
+  async generateResponse(message: string): Promise<string | undefined> {
     try {
       const { max_tokens } = GROQ_CONFIG;
       if (getGenerativeAiModel() === generativeAiModels.GROK) {
         this.baseUrl = getXGroKBaseURL();
       }
+      const userMessage = Message.of({ role: "user", content: message });
       let chatHistory = Memory.has(COMMON.ANTHROPIC_CHAT_HISTORY)
         ? Memory.get(COMMON.ANTHROPIC_CHAT_HISTORY)
-        : [];
+        : [userMessage];

-      if (chatHistory?.length) {
-        chatHistory = [...chatHistory, { role: "user", content: message }];
-      }
-
-      if (!chatHistory?.length) {
-        chatHistory = [{ role: "user", content: message }];
-      }
+      chatHistory = [...chatHistory, userMessage];

       Memory.removeItems(COMMON.ANTHROPIC_CHAT_HISTORY);
src/providers/base.ts

Lines changed: 11 additions & 25 deletions
@@ -26,7 +26,8 @@ export abstract class BaseWebViewProvider implements vscode.Disposable {
     this.orchestrator = Orchestrator.getInstance();
     this.logger = new Logger("BaseWebViewProvider");
     this.disposables.push(
-      this.orchestrator.onUpdate(this.subscribeToUpdate.bind(this)),
+      this.orchestrator.onResponse(this.handleModelResponseEvent.bind(this)),
+      this.orchestrator.onThinking(this.handleThinkingEvent.bind(this)),
     );
   }
@@ -75,16 +76,15 @@ export abstract class BaseWebViewProvider implements vscode.Disposable {
   ): Promise<void> {
     try {
       _view.webview.onDidReceiveMessage(async (message) => {
-        let response;
+        let response: any;
         if (message.command === "user-input") {
           if (message.tags?.length > 0) {
-            this.publishEvent(message);
-          } else {
             response = await this.generateResponse(
               message.message,
-              apiKey,
-              modelName,
+              message.tags,
             );
+          } else {
+            response = await this.generateResponse(message.message);
           }

           if (response) {
@@ -98,38 +98,24 @@ export abstract class BaseWebViewProvider implements vscode.Disposable {
       }
     }

-  public subscribeToUpdate(event: IEventPayload) {
-    this.sendResponse(JSON.stringify(event));
-    console.error(
-      `Error: ${event.message} (Code: ${JSON.stringify(event.data)})`,
-    );
+  public handleModelResponseEvent(event: IEventPayload) {
+    this.sendResponse(formatText(event.message), "bot");
   }

-  public async publishEvent(message: any) {
-    try {
-      const response = await this.generateContent(message.message);
-      if (response) {
-        this.orchestrator.publish("onQuery", JSON.stringify(response));
-      }
-    } catch (error) {
-      this.logger.error("Unable to publish generateContentEvent", error);
-      throw error;
-    }
+  public handleThinkingEvent(event: IEventPayload) {
+    this.sendResponse(formatText(event.message), "bot");
   }

   abstract generateResponse(
     message?: string,
-    apiKey?: string,
-    name?: string,
+    metaData?: Record<string, any>,
   ): Promise<string | undefined>;

   abstract sendResponse(
     response: string,
     currentChat?: string,
   ): Promise<boolean | undefined>;

-  abstract generateContent(userInput: string): any;
-
   public dispose(): void {
     this.disposables.forEach((d) => d.dispose());
   }
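
Concrete providers now implement a slimmer abstract contract: generateResponse receives the message plus optional metadata instead of an apiKey and model name. A hypothetical subclass sketch, assuming no other abstract members beyond those visible in this diff; the class name and method bodies are illustrative only:

import { BaseWebViewProvider } from "./src/providers/base";

// Hypothetical provider: only the signatures follow the new contract,
// the bodies are illustrative and not part of this commit.
class EchoWebViewProvider extends BaseWebViewProvider {
  async generateResponse(
    message?: string,
    metaData?: Record<string, any>,
  ): Promise<string | undefined> {
    // Tags from the webview arrive via metaData (see onDidReceiveMessage above).
    return message ? `echo(${JSON.stringify(metaData ?? {})}): ${message}` : undefined;
  }

  async sendResponse(
    response: string,
    currentChat?: string,
  ): Promise<boolean | undefined> {
    console.log(`[${currentChat ?? "bot"}] ${response}`);
    return true;
  }
}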
