Skip to content

Commit 01bb812

Browse files
committed
More comments to document how context gathering works
Signed-off-by: worksofliam <[email protected]>
1 parent 765e927 commit 01bb812

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

src/aiProviders/context.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ export async function getSqlContextItems(input: string): Promise<{items: Db2Cont
195195
}
196196

197197
const allObjects = await Schemas.resolveObjects(possibleRefs);
198-
199198
const contextItems = await getContentItemsForRefs(allObjects);
200199

201200
return {

src/aiProviders/prompt.ts

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,52 +41,64 @@ export async function getContextItems(input: string, options: PromptOptions = {}
4141
// TODO: self?
4242

4343
progress(`Finding objects to work with...`);
44-
const context = await getSqlContextItems(input);
4544

46-
contextItems.push(...context.items);
45+
// First, let's take the user input and see if contains any references to SQL objects.
46+
// This returns a list of references to SQL objects, such as tables, views, schemas, etc,
47+
// and the context items that are related to those references.
48+
const userInput = await getSqlContextItems(input);
4749

48-
if (context.refs.filter(r => r.sqlType === `TABLE`).length >= 2) {
49-
const randomIndexA = Math.floor(Math.random() * context.refs.length);
50-
const randomIndexB = Math.floor(Math.random() * context.refs.length);
51-
const tableA = context.refs[randomIndexA].name;
52-
const tableB = context.refs[randomIndexB].name;
50+
contextItems.push(...userInput.items);
51+
52+
// If the user referenced 2 or more tables, let's add a follow up
53+
if (userInput.refs.filter(r => r.sqlType === `TABLE`).length >= 2) {
54+
const randomIndexA = Math.floor(Math.random() * userInput.refs.length);
55+
const randomIndexB = Math.floor(Math.random() * userInput.refs.length);
56+
const tableA = userInput.refs[randomIndexA].name;
57+
const tableB = userInput.refs[randomIndexB].name;
5358

5459
if (tableA !== tableB) {
5560
followUps.push(`How can I join ${tableA} and ${tableB}?`);
5661
}
5762
}
5863

59-
// If the user only requests one reference, then let's find related objects
60-
if (context.refs.length === 1) {
61-
const ref = context.refs[0];
64+
// If the user only requests one reference, then let's do something
65+
if (userInput.refs.length === 1) {
66+
const ref = userInput.refs[0];
6267
const prettyNameRef = Statement.prettyName(ref.name);
6368

6469
if (ref.sqlType === `SCHEMA`) {
70+
// If the only reference is a schema, let's just add follow ups
6571
followUps.push(
6672
`What are some objects in that schema?`,
6773
`What is the difference between a schema and a library?`,
6874
);
75+
6976
} else {
77+
// If the user referenced a table, view, or other object, let's fetch related objects
7078
progress(`Finding objects related to ${prettyNameRef}...`);
7179

7280
const relatedObjects = await Schemas.getRelatedObjects(ref);
7381
const contentItems = await getContentItemsForRefs(relatedObjects);
7482

7583
contextItems.push(...contentItems);
7684

85+
// Then also add some follow ups
7786
if (relatedObjects.length === 1) {
7887
followUps.push(`How is ${prettyNameRef} related to ${Statement.prettyName(relatedObjects[0].name)}?`);
7988
} else if (ref.sqlType === `TABLE`) {
8089
followUps.push(`What are some objects related to that table?`);
8190
}
8291
}
8392

84-
} else if (context.refs.length > 1) {
85-
const randomRef = context.refs[Math.floor(Math.random() * context.refs.length)];
93+
} else if (userInput.refs.length > 1) {
94+
// If there are multiple references, let's just add a follow up
95+
const randomRef = userInput.refs[Math.floor(Math.random() * userInput.refs.length)];
8696
const prettyNameRef = Statement.prettyName(randomRef.name);
8797

8898
followUps.push(`What are some objects related to ${prettyNameRef}?`);
99+
89100
} else if (useSchemaDef) {
101+
// If the user didn't reference any objects, but we are using schema definitions, let's just add the schema definition
90102
progress(`Getting info for schema ${currentSchema}...`);
91103
const schemaSemantic = await buildSchemaDefinition(currentSchema);
92104
if (schemaSemantic) {

0 commit comments

Comments
 (0)