Skip to content

Commit ffffa76

Browse files
committed
Fetch related items when getting info about one object
Signed-off-by: worksofliam <[email protected]>
1 parent 5482125 commit ffffa76

File tree

3 files changed

+66
-14
lines changed

3 files changed

+66
-14
lines changed

src/aiProviders/context.ts

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ function splitUpUserInput(input: string): string[] {
167167
*
168168
* @param {string} input - A string that may contain table references.
169169
*/
170-
export async function getSqlContextItems(input: string): Promise<ContextDefinition[]> {
170+
export async function getSqlContextItems(input: string): Promise<{items: ContextDefinition[], refs: ResolvedSqlObject[]}> {
171171
// Parse all SCHEMA.TABLE references first
172172
const tokens = splitUpUserInput(input);
173173

@@ -201,30 +201,39 @@ export async function getSqlContextItems(input: string): Promise<ContextDefiniti
201201

202202
const allObjects = await Schemas.resolveObjects(possibleRefs);
203203

204-
const contextItems = (await Promise.all(
204+
const contextItems = await getContentItemsForRefs(allObjects);
205+
206+
return {
207+
items: contextItems,
208+
refs: allObjects,
209+
};
210+
}
211+
212+
export async function getContentItemsForRefs(allObjects: ResolvedSqlObject[]): Promise<ContextDefinition[]> {
213+
const items: (ContextDefinition|undefined)[] = await Promise.all(
205214
allObjects.map(async (o) => {
206215
try {
207216
if (o.sqlType === `SCHEMA`) {
208217
// TODO: maybe we want to include info about a schema here?
209218
return undefined;
210-
219+
211220
} else {
212221
const content = await Schemas.generateSQL(o.schema, o.name, o.sqlType);
213222

214223
return {
215224
id: o.name,
216225
type: o.sqlType,
217226
content: content,
218-
}
227+
};
219228
}
220229

221230
} catch (e) {
222231
return undefined;
223232
}
224233
})
225-
)).filter((item) => item !== undefined);
234+
);
226235

227-
return contextItems;
236+
return items.filter((item) => item !== undefined);
228237
}
229238

230239
/**

src/aiProviders/prompt.ts

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import { JobManager } from "../config";
22
import Configuration from "../configuration";
33
import { JobInfo } from "../connection/manager";
4-
import { buildSchemaDefinition, canTalkToDb, getSqlContextItems } from "./context";
4+
import Schemas from "../database/schemas";
5+
import Statement from "../database/statement";
6+
import { buildSchemaDefinition, canTalkToDb, getContentItemsForRefs, getSqlContextItems } from "./context";
57
import { DB2_SYSTEM_PROMPT } from "./prompts";
68

79
export interface PromptOptions {
@@ -48,21 +50,39 @@ export async function buildPrompt(input: string, options: PromptOptions = {}): P
4850
// TODO: self?
4951

5052
progress(`Finding objects to work with...`);
51-
const refs = await getSqlContextItems(input);
53+
const context = await getSqlContextItems(input);
5254

5355
if (options.history) {
5456
contextItems.push(...options.history);
5557
}
5658

57-
for (const table of refs) {
59+
for (const sqlObj of context.items) {
5860
contextItems.push({
59-
name: `table definition for ${table.id}`,
60-
content: table.content,
61-
description: `${table.type} definition`,
61+
name: `${sqlObj.type.toLowerCase()} definition for ${sqlObj.id}`,
62+
content: sqlObj.content,
63+
description: `${sqlObj.type} definition`,
6264
type: `assistant`
6365
});
6466
}
6567

68+
// If the user only requests one reference, then let's find related objects
69+
if (context.refs.length === 1) {
70+
const ref = context.refs[0];
71+
progress(`Finding objects related to ${Statement.prettyName(ref.name)}...`);
72+
73+
const relatedObjects = await Schemas.getRelatedObjects(ref);
74+
const contentItems = await getContentItemsForRefs(relatedObjects);
75+
76+
for (const sqlObj of contentItems) {
77+
contextItems.push({
78+
name: `${sqlObj.type.toLowerCase()} definition for ${sqlObj.id}`,
79+
content: sqlObj.content,
80+
description: `${sqlObj.type} definition`,
81+
type: `assistant`
82+
});
83+
}
84+
}
85+
6686
if (!options.history) {
6787
contextItems.push({
6888
name: `system prompt`,

src/database/schemas.ts

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ function getFilterClause(againstColumn: string, filter: string, noAnd?: boolean)
4646
};
4747
}
4848

49+
export interface ObjectReference {name: string, schema?: string};
50+
4951
const BASE_RESOLVE_SELECT = [
5052
`select `,
5153
`OBJLONGNAME as name, `,
@@ -60,7 +62,7 @@ export default class Schemas {
6062
/**
6163
* Resolves to the following SQL types: SCHEMA, TABLE, VIEW, ALIAS, INDEX, FUNCTION and PROCEDURE
6264
*/
63-
static async resolveObjects(sqlObjects: {name: string, schema?: string}[]): Promise<ResolvedSqlObject[]> {
65+
static async resolveObjects(sqlObjects: ObjectReference[]): Promise<ResolvedSqlObject[]> {
6466
let statements: string[] = [];
6567
let parameters: BasicColumnType[] = [];
6668

@@ -111,7 +113,7 @@ export default class Schemas {
111113
const query = `${statements.join(" UNION ALL ")}`;
112114
const objects: any[] = await JobManager.runSQL(query, { parameters });
113115

114-
const resolvedObjects: ResolvedSqlObject[] = objects.map(object => ({
116+
let resolvedObjects: ResolvedSqlObject[] = objects.map(object => ({
115117
name: object.NAME,
116118
schema: object.SCHEMA,
117119
sqlType: object.SQLTYPE
@@ -120,6 +122,27 @@ export default class Schemas {
120122
return resolvedObjects;
121123
}
122124

125+
static async getRelatedObjects(object: ResolvedSqlObject): Promise<ResolvedSqlObject[]> {
126+
const sql = [
127+
`with refs as (`,
128+
` SELECT `,
129+
` schema_name as schema, `,
130+
` sql_name as name, `,
131+
` case when sql_object_type = 'FOREIGN KEY' then 'TABLE' else sql_object_type end as type`,
132+
` FROM TABLE(SYSTOOLS.RELATED_OBJECTS(?, ?))`,
133+
`)`,
134+
`select * from refs `,
135+
`where type in ('TABLE', 'FUNCTION', 'PROCEDURE')`,
136+
].join(` `);
137+
138+
const related: any[] = await JobManager.runSQL(sql, { parameters: [object.schema, object.name] });
139+
return related.map(item => ({
140+
name: item.NAME,
141+
schema: item.SCHEMA,
142+
sqlType: item.TYPE
143+
}));
144+
}
145+
123146
/**
124147
* @param schema Not user input
125148
*/

0 commit comments

Comments
 (0)