Skip to content

Commit 47074d7

Browse files
Enhance AST code generation with runtime schema support
- Add generateTsAstCodeFromPgAstWithSchema function that uses runtime schema to determine node wrapping - Implement field-based node type detection for complex AST structures - Add comprehensive tests for wrapped vs unwrapped node generation - Support correct builder paths: t.ast.*() for unwrapped nodes, t.nodes.*() for wrapped nodes - Maintain backward compatibility with existing generateTsAstCodeFromPgAst function Co-Authored-By: Dan Lynch <[email protected]>
1 parent 35da98b commit 47074d7

File tree

3 files changed

+249
-2
lines changed

3 files changed

+249
-2
lines changed

packages/proto-parser/__tests__/__snapshots__/meta.test.ts.snap

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,34 @@ exports[`AST to AST to create AST — meta 🤯 2`] = `
5050
})"
5151
`;
5252

53+
exports[`Complex AST with runtime schema — mixed wrapped/unwrapped patterns 1`] = `
54+
"t.nodes.selectStmt({
55+
targetList: [t.nodes.resTarget({
56+
val: t.nodes.columnRef({
57+
fields: [t.nodes.aStar({})]
58+
})
59+
})],
60+
limitOption: "LIMIT_OPTION_DEFAULT",
61+
withClause: t.nodes.withClause({
62+
ctes: [t.nodes.commonTableExpr({
63+
ctename: "test_cte",
64+
ctequery: t.nodes.selectStmt({
65+
targetList: [t.nodes.resTarget({
66+
val: t.nodes.columnRef({
67+
fields: [t.nodes.string({
68+
sval: "id"
69+
})]
70+
})
71+
})],
72+
limitOption: "LIMIT_OPTION_DEFAULT"
73+
})
74+
})],
75+
recursive: false
76+
}),
77+
op: "SETOP_NONE"
78+
})"
79+
`;
80+
5381
exports[`Complex AST — Advanced SQL with CTEs, Window Functions, Joins, and Subqueries 1`] = `
5482
{
5583
"SelectStmt": {
@@ -1282,3 +1310,53 @@ exports[`Complex AST — Advanced SQL with CTEs, Window Functions, Joins, and Su
12821310
op: "SETOP_NONE"
12831311
})"
12841312
`;
1313+
1314+
exports[`Enhanced AST generation with runtime schemawrapped vs unwrapped nodes 1`] = `
1315+
{
1316+
"SelectStmt": {
1317+
"fromClause": [
1318+
{
1319+
"RangeVar": {
1320+
"inh": true,
1321+
"relname": "test_table",
1322+
"relpersistence": "p",
1323+
},
1324+
},
1325+
],
1326+
"limitOption": "LIMIT_OPTION_DEFAULT",
1327+
"op": "SETOP_NONE",
1328+
"targetList": [
1329+
{
1330+
"ResTarget": {
1331+
"val": {
1332+
"ColumnRef": {
1333+
"fields": [
1334+
{
1335+
"A_Star": {},
1336+
},
1337+
],
1338+
},
1339+
},
1340+
},
1341+
},
1342+
],
1343+
},
1344+
}
1345+
`;
1346+
1347+
exports[`Enhanced AST generation with runtime schemawrapped vs unwrapped nodes 2`] = `
1348+
"t.nodes.selectStmt({
1349+
targetList: [t.nodes.resTarget({
1350+
val: t.nodes.columnRef({
1351+
fields: [t.nodes.aStar({})]
1352+
})
1353+
})],
1354+
fromClause: [t.nodes.rangeVar({
1355+
relname: "test_table",
1356+
inh: true,
1357+
relpersistence: "p"
1358+
})],
1359+
limitOption: "LIMIT_OPTION_DEFAULT",
1360+
op: "SETOP_NONE"
1361+
})"
1362+
`;

packages/proto-parser/__tests__/meta.test.ts

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import * as t from '../test-utils/meta';
22
import { SelectStmt } from '@pgsql/types';
3-
import { generateTsAstCodeFromPgAst } from '../src/utils'
3+
import { generateTsAstCodeFromPgAst, generateTsAstCodeFromPgAstWithSchema } from '../src/utils'
4+
import { runtimeSchema } from '../test-utils/meta/runtime-schema';
45
import generate from '@babel/generator';
56

67
it('AST to AST to create AST — meta 🤯', () => {
@@ -371,4 +372,74 @@ it('Complex AST — Advanced SQL with CTEs, Window Functions, Joins, and Subquer
371372

372373
const astForComplexAst = generateTsAstCodeFromPgAst(complexSelectStmt);
373374
expect(generate(astForComplexAst).code).toMatchSnapshot();
374-
});
375+
});
376+
377+
it('Enhanced AST generation with runtime schema — wrapped vs unwrapped nodes', () => {
378+
const selectStmt = t.nodes.selectStmt({
379+
targetList: [
380+
t.nodes.resTarget({
381+
val: t.nodes.columnRef({
382+
fields: [t.nodes.aStar()]
383+
})
384+
})
385+
],
386+
fromClause: [
387+
t.nodes.rangeVar({
388+
relname: 'test_table',
389+
inh: true,
390+
relpersistence: 'p'
391+
})
392+
],
393+
limitOption: 'LIMIT_OPTION_DEFAULT',
394+
op: 'SETOP_NONE'
395+
});
396+
397+
expect(selectStmt).toMatchSnapshot();
398+
399+
const enhancedAst = generateTsAstCodeFromPgAstWithSchema(selectStmt, runtimeSchema);
400+
const generatedCode = generate(enhancedAst).code;
401+
402+
expect(generatedCode).toMatchSnapshot();
403+
404+
expect(generatedCode).toContain('t.nodes.selectStmt');
405+
expect(generatedCode).toContain('t.nodes.resTarget');
406+
});
407+
408+
it('Complex AST with runtime schema — mixed wrapped/unwrapped patterns', () => {
409+
const complexStmt = t.nodes.selectStmt({
410+
withClause: t.ast.withClause({
411+
ctes: [
412+
t.nodes.commonTableExpr({
413+
ctename: 'test_cte',
414+
ctequery: t.nodes.selectStmt({
415+
targetList: [
416+
t.nodes.resTarget({
417+
val: t.nodes.columnRef({
418+
fields: [t.nodes.string({ sval: 'id' })]
419+
})
420+
})
421+
],
422+
limitOption: 'LIMIT_OPTION_DEFAULT'
423+
})
424+
})
425+
],
426+
recursive: false
427+
}),
428+
targetList: [
429+
t.nodes.resTarget({
430+
val: t.nodes.columnRef({
431+
fields: [t.nodes.aStar()]
432+
})
433+
})
434+
],
435+
limitOption: 'LIMIT_OPTION_DEFAULT',
436+
op: 'SETOP_NONE'
437+
});
438+
439+
const enhancedAst = generateTsAstCodeFromPgAstWithSchema(complexStmt, runtimeSchema);
440+
const generatedCode = generate(enhancedAst).code;
441+
442+
expect(generatedCode).toMatchSnapshot();
443+
expect(generatedCode).toContain('t.nodes.withClause');
444+
expect(generatedCode).toContain('t.nodes.selectStmt');
445+
});

packages/proto-parser/src/utils/meta.ts

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { toSpecialCamelCase } from './index';
22
import * as t from '@babel/types';
3+
import { NodeSpec, FieldSpec } from '../runtime-schema/types';
34

45
/**
56
* Converts an AST (Abstract Syntax Tree) representation of a SQL query into
@@ -56,4 +57,101 @@ export function generateTsAstCodeFromPgAst(ast: any): any {
5657

5758
return traverse(ast);
5859
}
60+
61+
export function generateTsAstCodeFromPgAstWithSchema(ast: any, runtimeSchema: NodeSpec[]): any {
62+
const schemaMap = new Map<string, NodeSpec>();
63+
runtimeSchema.forEach(spec => {
64+
schemaMap.set(spec.name, spec);
65+
});
66+
67+
function createAstNode(functionName: string, properties: any, isWrapped: boolean = true) {
68+
const args = properties.map(([propKey, propValue]: [string, any]) => {
69+
return t.objectProperty(t.identifier(propKey), getValueNode(propValue));
70+
});
71+
72+
const builderPath = isWrapped ? 'nodes' : 'ast';
73+
return t.callExpression(
74+
t.memberExpression(
75+
t.memberExpression(t.identifier('t'), t.identifier(builderPath)),
76+
t.identifier(functionName)
77+
),
78+
[t.objectExpression(args)]
79+
);
80+
}
81+
82+
function getValueNode(value: any): t.Expression {
83+
if (Array.isArray(value)) {
84+
return t.arrayExpression(value.map(item => getValueNode(item)));
85+
} else if (typeof value === 'object') {
86+
return value === null ? t.nullLiteral() : traverse(value);
87+
}
88+
switch (typeof value) {
89+
case 'boolean':
90+
return t.booleanLiteral(value);
91+
case 'number':
92+
return t.numericLiteral(value);
93+
case 'string':
94+
return t.stringLiteral(value);
95+
default:
96+
return t.stringLiteral(String(value)); // Fallback for other types
97+
}
98+
}
99+
100+
function findNodeTypeByFields(fieldNames: string[]): NodeSpec | null {
101+
for (const nodeSpec of runtimeSchema) {
102+
const specFieldNames = nodeSpec.fields.map(f => f.name).sort();
103+
const sortedFieldNames = [...fieldNames].sort();
104+
105+
const hasAllRequiredFields = specFieldNames.every(fieldName =>
106+
sortedFieldNames.includes(fieldName) ||
107+
nodeSpec.fields.find(f => f.name === fieldName)?.optional
108+
);
109+
const hasOnlyValidFields = sortedFieldNames.every(fieldName =>
110+
specFieldNames.includes(fieldName)
111+
);
112+
113+
if (hasAllRequiredFields && hasOnlyValidFields && sortedFieldNames.length > 0) {
114+
return nodeSpec;
115+
}
116+
}
117+
return null;
118+
}
119+
120+
function traverse(node: any): t.Expression {
121+
if (Array.isArray(node)) {
122+
return t.arrayExpression(node.map(traverse));
123+
} else if (node && typeof node === 'object') {
124+
const entries = Object.entries(node);
125+
if (entries.length === 0) return t.objectExpression([]);
126+
127+
if (entries.length === 1) {
128+
const [key, value] = entries[0];
129+
const functionName = toSpecialCamelCase(key);
130+
131+
const nodeSpec = schemaMap.get(key);
132+
const isWrapped = nodeSpec ? nodeSpec.isNode : true; // Default to wrapped if not found
133+
134+
return createAstNode(functionName, Object.entries(value), isWrapped);
135+
} else {
136+
const fieldNames = entries.map(([key]) => key);
137+
const matchingNodeSpec = findNodeTypeByFields(fieldNames);
138+
139+
if (matchingNodeSpec) {
140+
const functionName = toSpecialCamelCase(matchingNodeSpec.name);
141+
const isWrapped = matchingNodeSpec.isNode;
142+
return createAstNode(functionName, entries, isWrapped);
143+
} else {
144+
const properties = entries.map(([propKey, propValue]) => {
145+
return t.objectProperty(t.identifier(propKey), getValueNode(propValue));
146+
});
147+
return t.objectExpression(properties);
148+
}
149+
}
150+
}
151+
152+
return getValueNode(node);
153+
}
154+
155+
return traverse(ast);
156+
}
59157

0 commit comments

Comments
 (0)