Skip to content

Commit 360af3a

Browse files
committed
Support for a result set tool
Signed-off-by: worksofliam <[email protected]>
1 parent d4f6b0e commit 360af3a

File tree

4 files changed

+154
-16
lines changed

4 files changed

+154
-16
lines changed

package.json

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,31 @@
13001300
]
13011301
}
13021302
],
1303+
"languageModelTools": [
1304+
{
1305+
"name": "vscode-db2i-chat-sqlRunnerTool",
1306+
"tags": [
1307+
"sql"
1308+
],
1309+
"canBeReferencedInPrompt": true,
1310+
"toolReferenceName": "result",
1311+
"displayName": "Run SQL statement",
1312+
"icon": "$(play)",
1313+
"modelDescription": "Run an SQL statement and return the result",
1314+
"inputSchema": {
1315+
"type": "object",
1316+
"properties": {
1317+
"statement": {
1318+
"type": "string",
1319+
"description": "The statement to execute"
1320+
}
1321+
},
1322+
"required": [
1323+
"statement"
1324+
]
1325+
}
1326+
}
1327+
],
13031328
"snippets": [
13041329
{
13051330
"language": "sql",

src/aiProviders/copilot/contributes.json

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,31 @@
1414
}
1515
]
1616
}
17+
],
18+
"languageModelTools": [
19+
{
20+
"name": "vscode-db2i-chat-sqlRunnerTool",
21+
"tags": [
22+
"sql"
23+
],
24+
"canBeReferencedInPrompt": true,
25+
"toolReferenceName": "result",
26+
"displayName": "Run SQL statement",
27+
"icon": "$(play)",
28+
"modelDescription": "Run an SQL statement and return the result",
29+
"inputSchema": {
30+
"type": "object",
31+
"properties": {
32+
"statement": {
33+
"type": "string",
34+
"description": "The statement to execute"
35+
}
36+
},
37+
"required": [
38+
"statement"
39+
]
40+
}
41+
}
1742
]
1843
}
1944
}

src/aiProviders/copilot/index.ts

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import {
33
canTalkToDb,
44
} from "../context";
55
import { buildPrompt, Db2ContextItems } from "../prompt";
6+
import { registerSqlRunTool, RUN_SQL_TOOL_ID } from "./sqlTool";
67

78
const CHAT_ID = `vscode-db2i.chat`;
89

@@ -56,7 +57,7 @@ export function activateChat(context: vscode.ExtensionContext) {
5657
stream.progress(`Building response...`);
5758

5859
// get history
59-
let history: Db2ContextItems[]|undefined;
60+
let history: Db2ContextItems[] | undefined;
6061
if (context.history.length > 0) {
6162
history = context.history.map((h) => {
6263
if ("prompt" in h) {
@@ -86,21 +87,49 @@ export function activateChat(context: vscode.ExtensionContext) {
8687
progress: stream.progress
8788
});
8889

89-
const messages = contextItems.context.map(c => {
90+
let messages = contextItems.context.map(c => {
9091
if (c.type === `user`) {
9192
return vscode.LanguageModelChatMessage.User(c.content);
9293
} else {
9394
return vscode.LanguageModelChatMessage.Assistant(c.content);
9495
}
9596
});
9697

97-
const result = await copilotRequest(
98-
request.model.family,
99-
messages,
100-
{},
101-
token,
102-
stream
103-
);
98+
const tools = vscode.lm.tools.filter(t => request.toolReferences.some(r => r.name === t.name));
99+
100+
const doRequest = (tools: vscode.LanguageModelToolInformation[] = []) => {
101+
return copilotRequest(
102+
request.model.family,
103+
messages,
104+
{
105+
tools,
106+
toolMode: vscode.LanguageModelChatToolMode.Required
107+
},
108+
token,
109+
stream
110+
);
111+
}
112+
113+
let result = await doRequest(tools);
114+
115+
if (result.toolCalls.length > 0) {
116+
for (const toolcall of result.toolCalls) {
117+
if (toolcall.name === RUN_SQL_TOOL_ID) {
118+
const result = await vscode.lm.invokeTool(toolcall.name, {toolInvocationToken: request.toolInvocationToken, input: toolcall.input});
119+
const resultOut = result.content.map(c => {
120+
if (c instanceof vscode.LanguageModelTextPart) {
121+
return c.value;
122+
}
123+
}).filter(c => c !== undefined).join("\n\n");
124+
125+
messages = [
126+
vscode.LanguageModelChatMessage.User(`Please review and summarize the following result set:\n\n${resultOut}\n\nThe original user request was: ${request}`)
127+
];
128+
}
129+
}
130+
131+
result = await doRequest();
132+
}
104133

105134
return { metadata: { command: "build", followUps: contextItems.followUps, statement: result.sqlCodeBlock } };
106135
}
@@ -131,10 +160,12 @@ export function activateChat(context: vscode.ExtensionContext) {
131160
}
132161

133162
context.subscriptions.push(chat);
163+
registerSqlRunTool(context);
134164
}
135165

136166
interface Result {
137167
output: string;
168+
toolCalls: vscode.LanguageModelToolCallPart[];
138169
sqlCodeBlock?: string;
139170
}
140171

@@ -144,18 +175,25 @@ async function copilotRequest(
144175
options: vscode.LanguageModelChatRequestOptions,
145176
token: vscode.CancellationToken,
146177
stream: vscode.ChatResponseStream
147-
): Promise<Result|undefined> {
178+
): Promise<Result | undefined> {
148179
const models = await vscode.lm.selectChatModels({ family: model });
149180
if (models.length > 0) {
150181
const [first] = models;
182+
options.justification = `Doing cool stuff`
151183
const response = await first.sendRequest(messages, options, token);
152-
let result: Result = {
153-
output: "",
154-
}
155184

156-
for await (const fragment of response.text) {
157-
stream.markdown(fragment);
158-
result.output += fragment;
185+
const result: Result = {
186+
output: "",
187+
toolCalls: []
188+
};
189+
190+
for await (const fragment of response.stream) {
191+
if (fragment instanceof vscode.LanguageModelTextPart) {
192+
stream.markdown(fragment.value);
193+
result.output += fragment.value;
194+
} else if (fragment instanceof vscode.LanguageModelToolCallPart) {
195+
result.toolCalls.push(fragment);
196+
}
159197
}
160198

161199
const codeBlockStart = result.output.indexOf("```sql");

src/aiProviders/copilot/sqlTool.ts

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import { CancellationToken, ExtensionContext, LanguageModelTextPart, LanguageModelTool, LanguageModelToolInvocationOptions, LanguageModelToolInvocationPrepareOptions, LanguageModelToolResult, lm, MarkdownString } from "vscode";
2+
import { JobManager } from "../../config";
3+
4+
interface IRunInTerminalParameters {
5+
statement: string;
6+
}
7+
8+
export const RUN_SQL_TOOL_ID = 'vscode-db2i-chat-sqlRunnerTool';
9+
10+
export function registerSqlRunTool(context: ExtensionContext) {
11+
context.subscriptions.push(lm.registerTool(RUN_SQL_TOOL_ID, new RunSqlTool()));
12+
}
13+
14+
class RunSqlTool
15+
implements LanguageModelTool<IRunInTerminalParameters> {
16+
async invoke(
17+
options: LanguageModelToolInvocationOptions<IRunInTerminalParameters>,
18+
_token: CancellationToken
19+
) {
20+
const params = options.input as IRunInTerminalParameters;
21+
22+
let trimmed = params.statement.trim();
23+
24+
if (trimmed.endsWith(`;`)) {
25+
trimmed = trimmed.slice(0, -1);
26+
}
27+
28+
const result = await JobManager.runSQL(trimmed);
29+
30+
return new LanguageModelToolResult([new LanguageModelTextPart(JSON.stringify(result))]);
31+
}
32+
33+
async prepareInvocation(
34+
options: LanguageModelToolInvocationPrepareOptions<IRunInTerminalParameters>,
35+
_token: CancellationToken
36+
) {
37+
const confirmationMessages = {
38+
title: 'Run SQL statement',
39+
message: new MarkdownString(
40+
`Run this statement in your job?` +
41+
`\n\n\`\`\`sql\n${options.input.statement}\n\`\`\`\n`
42+
),
43+
};
44+
45+
return {
46+
invocationMessage: `Running statement`,
47+
confirmationMessages,
48+
};
49+
}
50+
}

0 commit comments

Comments
 (0)