Skip to content

Commit dd65445

Browse files
Fix withClause wrapping logic to use t.ast.* for specific node types
- Fixed multi-field object processing in traverse() to check parent field specs - WithClause now correctly generates t.ast.withClause instead of t.nodes.withClause - Updated snapshot test to reflect correct expected output - All 42 tests passing Co-Authored-By: Dan Lynch <[email protected]>
1 parent 47074d7 commit dd65445

File tree

3 files changed

+55
-14
lines changed

3 files changed

+55
-14
lines changed

packages/proto-parser/__tests__/__snapshots__/meta.test.ts.snap

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ exports[`Complex AST with runtime schema — mixed wrapped/unwrapped patterns 1`
5858
})
5959
})],
6060
limitOption: "LIMIT_OPTION_DEFAULT",
61-
withClause: t.nodes.withClause({
61+
withClause: t.ast.withClause({
6262
ctes: [t.nodes.commonTableExpr({
6363
ctename: "test_cte",
6464
ctequery: t.nodes.selectStmt({

packages/proto-parser/__tests__/meta.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,6 @@ it('Complex AST with runtime schema — mixed wrapped/unwrapped patterns', () =>
440440
const generatedCode = generate(enhancedAst).code;
441441

442442
expect(generatedCode).toMatchSnapshot();
443-
expect(generatedCode).toContain('t.nodes.withClause');
443+
expect(generatedCode).toContain('t.ast.withClause');
444444
expect(generatedCode).toContain('t.nodes.selectStmt');
445445
});

packages/proto-parser/src/utils/meta.ts

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ export function generateTsAstCodeFromPgAstWithSchema(ast: any, runtimeSchema: No
6666

6767
function createAstNode(functionName: string, properties: any, isWrapped: boolean = true) {
6868
const args = properties.map(([propKey, propValue]: [string, any]) => {
69+
if (propValue && typeof propValue === 'object' && propValue.type) {
70+
return t.objectProperty(t.identifier(propKey), propValue);
71+
}
6972
return t.objectProperty(t.identifier(propKey), getValueNode(propValue));
7073
});
7174

@@ -79,11 +82,21 @@ export function generateTsAstCodeFromPgAstWithSchema(ast: any, runtimeSchema: No
7982
);
8083
}
8184

82-
function getValueNode(value: any): t.Expression {
85+
function getValueNode(value: any, parentNodeType?: string, fieldName?: string): t.Expression {
8386
if (Array.isArray(value)) {
84-
return t.arrayExpression(value.map(item => getValueNode(item)));
87+
return t.arrayExpression(value.map(item => getValueNode(item, parentNodeType, fieldName)));
8588
} else if (typeof value === 'object') {
86-
return value === null ? t.nullLiteral() : traverse(value);
89+
if (value === null) return t.nullLiteral();
90+
91+
if (parentNodeType && fieldName) {
92+
const parentSpec = schemaMap.get(parentNodeType);
93+
if (parentSpec) {
94+
const fieldSpec = parentSpec.fields.find(f => f.name === fieldName);
95+
96+
}
97+
}
98+
99+
return traverse(value, parentNodeType, fieldName);
87100
}
88101
switch (typeof value) {
89102
case 'boolean':
@@ -117,9 +130,9 @@ export function generateTsAstCodeFromPgAstWithSchema(ast: any, runtimeSchema: No
117130
return null;
118131
}
119132

120-
function traverse(node: any): t.Expression {
133+
function traverse(node: any, parentNodeType?: string, fieldName?: string): t.Expression {
121134
if (Array.isArray(node)) {
122-
return t.arrayExpression(node.map(traverse));
135+
return t.arrayExpression(node.map(item => traverse(item, parentNodeType, fieldName)));
123136
} else if (node && typeof node === 'object') {
124137
const entries = Object.entries(node);
125138
if (entries.length === 0) return t.objectExpression([]);
@@ -128,28 +141,56 @@ export function generateTsAstCodeFromPgAstWithSchema(ast: any, runtimeSchema: No
128141
const [key, value] = entries[0];
129142
const functionName = toSpecialCamelCase(key);
130143

131-
const nodeSpec = schemaMap.get(key);
132-
const isWrapped = nodeSpec ? nodeSpec.isNode : true; // Default to wrapped if not found
144+
let isWrapped = true;
145+
146+
if (parentNodeType && fieldName) {
147+
const parentSpec = schemaMap.get(parentNodeType);
148+
if (parentSpec) {
149+
const fieldSpec = parentSpec.fields.find(f => f.name === fieldName);
150+
if (fieldSpec && fieldSpec.isNode && fieldSpec.type !== 'Node') {
151+
isWrapped = false;
152+
}
153+
}
154+
}
155+
156+
const processedProperties = Object.entries(value).map(([propKey, propValue]) => {
157+
return [propKey, getValueNode(propValue, key, propKey)];
158+
});
133159

134-
return createAstNode(functionName, Object.entries(value), isWrapped);
160+
return createAstNode(functionName, processedProperties, isWrapped);
135161
} else {
136162
const fieldNames = entries.map(([key]) => key);
137163
const matchingNodeSpec = findNodeTypeByFields(fieldNames);
138164

139165
if (matchingNodeSpec) {
140166
const functionName = toSpecialCamelCase(matchingNodeSpec.name);
141-
const isWrapped = matchingNodeSpec.isNode;
142-
return createAstNode(functionName, entries, isWrapped);
167+
168+
let isWrapped = true;
169+
if (parentNodeType && fieldName) {
170+
const parentSpec = schemaMap.get(parentNodeType);
171+
if (parentSpec) {
172+
const parentFieldSpec = parentSpec.fields.find(f => f.name === fieldName);
173+
if (parentFieldSpec && parentFieldSpec.isNode && parentFieldSpec.type !== 'Node') {
174+
isWrapped = false;
175+
}
176+
}
177+
}
178+
179+
const processedProperties = entries.map(([propKey, propValue]) => {
180+
return [propKey, traverse(propValue, matchingNodeSpec.name, propKey)];
181+
});
182+
183+
return createAstNode(functionName, processedProperties, isWrapped);
143184
} else {
144185
const properties = entries.map(([propKey, propValue]) => {
145-
return t.objectProperty(t.identifier(propKey), getValueNode(propValue));
186+
return t.objectProperty(t.identifier(propKey), traverse(propValue, parentNodeType, propKey));
146187
});
147188
return t.objectExpression(properties);
148189
}
149190
}
150191
}
151192

152-
return getValueNode(node);
193+
return getValueNode(node, parentNodeType, fieldName);
153194
}
154195

155196
return traverse(ast);

0 commit comments

Comments
 (0)