@@ -69,15 +69,13 @@ export function generateZodSchemaVariableStatement({
6969 | ts . CallExpression
7070 | ts . Identifier
7171 | ts . PropertyAccessExpression
72+ | ts . ArrowFunction
7273 | undefined ;
7374 let dependencies : string [ ] = [ ] ;
7475 let requiresImport = false ;
7576
7677 if ( ts . isInterfaceDeclaration ( node ) ) {
7778 let schemaExtensionClauses : string [ ] | undefined ;
78- if ( node . typeParameters ) {
79- throw new Error ( "Interface with generics are not supported!" ) ;
80- }
8179 if ( node . heritageClauses ) {
8280 // Looping on heritageClauses browses the "extends" keywords
8381 schemaExtensionClauses = node . heritageClauses . reduce (
@@ -111,9 +109,6 @@ export function generateZodSchemaVariableStatement({
111109 }
112110
113111 if ( ts . isTypeAliasDeclaration ( node ) ) {
114- if ( node . typeParameters ) {
115- throw new Error ( "Type with generics are not supported!" ) ;
116- }
117112 const jsDocTags = skipParseJSDoc ? { } : getJSDocTags ( node , sourceFile ) ;
118113
119114 schema = buildZodPrimitive ( {
@@ -133,6 +128,39 @@ export function generateZodSchemaVariableStatement({
133128 requiresImport = true ;
134129 }
135130
131+ // process generic dependencies
132+ if ( ts . isInterfaceDeclaration ( node ) || ts . isTypeAliasDeclaration ( node ) ) {
133+ if ( schema !== undefined && node . typeParameters ) {
134+ const genericTypes = node
135+ . typeParameters . map ( ( p ) => `${ p . name . escapedText } ` )
136+ const genericDependencies = genericTypes . map ( ( p ) => getDependencyName ( p ) )
137+ dependencies = dependencies
138+ . filter ( ( dep ) => ! genericDependencies . includes ( dep ) ) ;
139+ schema = f . createArrowFunction (
140+ undefined ,
141+ genericTypes . map ( ( dep ) => f . createIdentifier ( dep ) ) ,
142+ genericTypes . map ( ( dep ) => f . createParameterDeclaration (
143+ undefined ,
144+ undefined ,
145+ undefined ,
146+ f . createIdentifier ( getDependencyName ( dep ) ) ,
147+ undefined ,
148+ f . createTypeReferenceNode (
149+ f . createQualifiedName (
150+ f . createIdentifier ( zodImportValue ) ,
151+ f . createIdentifier ( `ZodSchema<${ dep } >` )
152+ ) ,
153+ undefined
154+ ) ,
155+ undefined
156+ ) ) ,
157+ undefined ,
158+ f . createToken ( ts . SyntaxKind . EqualsGreaterThanToken ) ,
159+ schema
160+ ) ;
161+ }
162+ }
163+
136164 return {
137165 dependencies : uniq ( dependencies ) ,
138166 statement : f . createVariableStatement (
@@ -205,7 +233,37 @@ function buildZodProperties({
205233 return properties ;
206234}
207235
208- function buildZodPrimitive ( {
236+ // decorate builder to allow for schema appending/overriding
237+ function buildZodPrimitive ( opts : {
238+ z : string ;
239+ typeNode : ts . TypeNode ;
240+ isOptional : boolean ;
241+ isNullable ?: boolean ;
242+ isPartial ?: boolean ;
243+ isRequired ?: boolean ;
244+ jsDocTags : JSDocTags ;
245+ sourceFile : ts . SourceFile ;
246+ dependencies : string [ ] ;
247+ getDependencyName : ( identifierName : string ) => string ;
248+ skipParseJSDoc : boolean ;
249+ } ) {
250+ const schema = opts . jsDocTags . schema
251+ delete opts . jsDocTags . schema
252+ const generatedSchema = _buildZodPrimitive ( opts ) ;
253+ // schema not specified? return generated one
254+ if ( ! schema ) return generatedSchema ;
255+ // schema starts with dot? append it
256+ if ( schema . startsWith ( "." ) ) {
257+ return f . createPropertyAccessExpression ( generatedSchema , f . createIdentifier ( schema . slice ( 1 ) ) ) ;
258+ }
259+ // otherwise use schema verbatim
260+ return f . createPropertyAccessExpression (
261+ f . createIdentifier ( opts . z ) ,
262+ f . createIdentifier ( schema )
263+ ) ;
264+ }
265+
266+ function _buildZodPrimitive ( {
209267 z,
210268 typeNode,
211269 isOptional,
@@ -493,6 +551,18 @@ function buildZodPrimitive({
493551
494552 const nodes = typeNode . types . filter ( isNotNull ) ;
495553
554+ // string-only enum? issue z.enum
555+ if ( typeNode . types . every ( ( i ) =>
556+ ts . isLiteralTypeNode ( i ) && i . literal . kind === ts . SyntaxKind . StringLiteral
557+ ) ) {
558+ return buildZodSchema (
559+ z ,
560+ "enum" ,
561+ [ f . createArrayLiteralExpression ( nodes . map ( ( i ) => ( i as any ) [ "literal" ] ) ) ] ,
562+ zodProperties
563+ ) ;
564+ }
565+
496566 // type A = | 'b' is a valid typescript definition
497567 // Zod does not allow `z.union(['b']), so we have to return just the value
498568 if ( nodes . length === 1 ) {
@@ -530,10 +600,16 @@ function buildZodPrimitive({
530600 } ) ;
531601 }
532602
603+ // discriminator specified? use discriminatedUnion
533604 return buildZodSchema (
534605 z ,
535- "union" ,
536- [ f . createArrayLiteralExpression ( values ) ] ,
606+ jsDocTags . discriminator !== undefined
607+ ? "discriminatedUnion"
608+ : "union" ,
609+ jsDocTags . discriminator !== undefined
610+ ? [ f . createStringLiteral ( jsDocTags . discriminator ) ,
611+ f . createArrayLiteralExpression ( values ) ]
612+ : [ f . createArrayLiteralExpression ( values ) ] ,
537613 zodProperties
538614 ) ;
539615 }
@@ -726,6 +802,45 @@ function buildZodPrimitive({
726802 ) ;
727803 }
728804
805+ /*
806+ // TRPC procedures? how to iterate over interface methods?
807+ if (ts.isFunctionTypeNode(typeNode)) {
808+ let exp = f.createPropertyAccessExpression(f.createIdentifier("t.procedure"), f.createIdentifier("input"));
809+ exp = f.createCallExpression(exp, undefined, typeNode.parameters.map((p) =>
810+ buildZodPrimitive({
811+ z,
812+ typeNode:
813+ p.type || f.createKeywordTypeNode(ts.SyntaxKind.AnyKeyword),
814+ jsDocTags,
815+ sourceFile,
816+ dependencies,
817+ getDependencyName,
818+ isOptional: Boolean(p.questionToken),
819+ skipParseJSDoc,
820+ })
821+ ));
822+ exp = f.createPropertyAccessExpression(exp, f.createIdentifier("output"));
823+ exp = f.createCallExpression(exp, undefined, [
824+ buildZodPrimitive({
825+ z,
826+ typeNode: typeNode.type,
827+ jsDocTags,
828+ sourceFile,
829+ dependencies,
830+ getDependencyName,
831+ isOptional: false,
832+ skipParseJSDoc,
833+ }),
834+ ]);
835+ exp = f.createPropertyAccessExpression(exp, f.createIdentifier("query"));
836+ exp = f.createCallExpression(exp, undefined, [
837+ // f.createIdentifier(`({ ctx, input }) => { throw new TRPCError({ code: "NOT_FOUND", cause: { ctx, input } }) }`)
838+ f.createIdentifier(`({ ctx, input }) => { }`)
839+ ]);
840+ return exp;
841+ }
842+ */
843+
729844 if ( ts . isIndexedAccessTypeNode ( typeNode ) ) {
730845 return buildSchemaReference ( {
731846 node : typeNode ,
0 commit comments