Skip to content

Commit 061ba64

Browse files
committed
Follow up questions
Signed-off-by: worksofliam <[email protected]>
1 parent ffffa76 commit 061ba64

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

src/aiProviders/copilot/index.ts

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ const CHAT_ID = `vscode-db2i.chat`;
99
interface IDB2ChatResult extends vscode.ChatResult {
1010
metadata: {
1111
command: string;
12+
followUps: string[];
1213
};
1314
}
1415

@@ -84,7 +85,7 @@ export function activateChat(context: vscode.ExtensionContext) {
8485
progress: stream.progress
8586
});
8687

87-
const messages = contextItems.map(c => {
88+
const messages = contextItems.context.map(c => {
8889
if (c.type === `user`) {
8990
return vscode.LanguageModelChatMessage.User(c.content);
9091
} else {
@@ -100,7 +101,7 @@ export function activateChat(context: vscode.ExtensionContext) {
100101
stream
101102
);
102103

103-
return { metadata: { command: "build" } };
104+
return { metadata: { command: "build", followUps: contextItems.followUps } };
104105
}
105106
} else {
106107
throw new Error(
@@ -111,6 +112,22 @@ export function activateChat(context: vscode.ExtensionContext) {
111112

112113
const chat = vscode.chat.createChatParticipant(CHAT_ID, chatHandler);
113114
chat.iconPath = new vscode.ThemeIcon(`database`);
115+
chat.followupProvider = {
116+
provideFollowups(result, context, token) {
117+
const followups: vscode.ChatFollowup[] = [];
118+
119+
if (result.metadata) {
120+
for (const followup of result.metadata.followUps) {
121+
followups.push({
122+
prompt: followup,
123+
participant: CHAT_ID,
124+
});
125+
}
126+
}
127+
128+
return followups;
129+
},
130+
}
114131

115132
context.subscriptions.push(chat);
116133
}

src/aiProviders/prompt.ts

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,16 @@ export interface Db2ContextItems {
1919
specific?: "copilot"|"continue";
2020
}
2121

22-
export async function buildPrompt(input: string, options: PromptOptions = {}): Promise<Db2ContextItems[]> {
22+
export interface BuildResult {
23+
context: Db2ContextItems[];
24+
followUps: string[];
25+
}
26+
27+
export async function buildPrompt(input: string, options: PromptOptions = {}): Promise<BuildResult> {
2328
const currentJob: JobInfo = JobManager.getSelection();
29+
2430
let contextItems: Db2ContextItems[] = [];
31+
let followUps = [];
2532

2633
const progress = (message: string) => {
2734
if (options.progress) {
@@ -65,14 +72,32 @@ export async function buildPrompt(input: string, options: PromptOptions = {}): P
6572
});
6673
}
6774

75+
if (context.refs.filter(r => r.sqlType === `TABLE`).length >= 2) {
76+
const randomIndexA = Math.floor(Math.random() * context.refs.length);
77+
const randomIndexB = Math.floor(Math.random() * context.refs.length);
78+
const tableA = context.refs[randomIndexA].name;
79+
const tableB = context.refs[randomIndexB].name;
80+
81+
if (tableA !== tableB) {
82+
followUps.push(`How can I join ${tableA} and ${tableB}?`);
83+
}
84+
}
85+
6886
// If the user only requests one reference, then let's find related objects
6987
if (context.refs.length === 1) {
7088
const ref = context.refs[0];
71-
progress(`Finding objects related to ${Statement.prettyName(ref.name)}...`);
89+
const prettyNameRef = Statement.prettyName(ref.name);
90+
progress(`Finding objects related to ${prettyNameRef}...`);
7291

7392
const relatedObjects = await Schemas.getRelatedObjects(ref);
7493
const contentItems = await getContentItemsForRefs(relatedObjects);
7594

95+
if (relatedObjects.length === 1) {
96+
followUps.push(`How is ${prettyNameRef} related to ${Statement.prettyName(relatedObjects[0].name)}?`);
97+
} else if (ref.sqlType === `TABLE`) {
98+
followUps.push(`What are some objects related to that table?`);
99+
}
100+
76101
for (const sqlObj of contentItems) {
77102
contextItems.push({
78103
name: `${sqlObj.type.toLowerCase()} definition for ${sqlObj.id}`,
@@ -81,6 +106,12 @@ export async function buildPrompt(input: string, options: PromptOptions = {}): P
81106
type: `assistant`
82107
});
83108
}
109+
110+
} else if (context.refs.length > 1) {
111+
const randomRef = context.refs[Math.floor(Math.random() * context.refs.length)];
112+
const prettyNameRef = Statement.prettyName(randomRef.name);
113+
114+
followUps.push(`What are some objects related to ${prettyNameRef}?`);
84115
}
85116

86117
if (!options.history) {
@@ -100,5 +131,8 @@ export async function buildPrompt(input: string, options: PromptOptions = {}): P
100131
});
101132
}
102133

103-
return contextItems;
134+
return {
135+
context: contextItems,
136+
followUps
137+
};
104138
}

0 commit comments

Comments
 (0)