@@ -88,10 +88,8 @@ func (c *Converter) AddType(input interface{}) {
8888 return
8989 }
9090
91- data := c .convertStructTopLevel (t )
92- order := c .structs
93- c .outputs [name ] = entry {order , data }
94- c .structs = order + 1
91+ data , selfRef := c .convertStructTopLevel (t )
92+ c .addSchema (name , data , selfRef )
9593}
9694
9795// Convert returns zod schema corresponding to a struct type. Its a shorthand for
@@ -143,8 +141,9 @@ var typeMapping = map[reflect.Kind]string{
143141}
144142
145143type entry struct {
146- order int
147- data string
144+ order int
145+ data string
146+ selfRef bool
148147}
149148
150149type byOrder []entry
@@ -170,12 +169,12 @@ type Converter struct {
170169 stack []meta
171170}
172171
173- func (c * Converter ) addSchema (name string , data string ) {
172+ func (c * Converter ) addSchema (name string , data string , selfRef bool ) {
174173 // First check if the object already exists. If it does do not replace. This is needed for second order
175174 _ , ok := c .outputs [name ]
176175 if ! ok {
177176 order := c .structs
178- c .outputs [name ] = entry {order , data }
177+ c .outputs [name ] = entry {order , data , selfRef }
179178 c .structs = order + 1
180179 }
181180}
@@ -203,6 +202,10 @@ func schemaName(prefix, name string) string {
203202 return fmt .Sprintf ("%s%sSchema" , prefix , name )
204203}
205204
205+ func shapeName (prefix , name string ) string {
206+ return schemaName (prefix , name ) + "Shape"
207+ }
208+
206209func fieldName (input reflect.StructField ) string {
207210 if json := input .Tag .Get ("json" ); json != "" {
208211 args := strings .Split (json , "," )
@@ -237,7 +240,7 @@ func typeName(t reflect.Type) string {
237240 return "UNKNOWN"
238241}
239242
240- func (c * Converter ) convertStructTopLevel (t reflect.Type ) string {
243+ func (c * Converter ) convertStructTopLevel (t reflect.Type ) ( string , bool ) {
241244 output := strings.Builder {}
242245
243246 name := typeName (t )
@@ -248,11 +251,16 @@ func (c *Converter) convertStructTopLevel(t reflect.Type) string {
248251
249252 top := c .stack [len (c .stack )- 1 ]
250253 if top .selfRef {
254+ shapeName := shapeName (c .prefix , name )
255+
251256 output .WriteString (fmt .Sprintf (`export type %s = %s
252257` , fullName , c .getTypeStruct (t , 0 )))
253258
259+ output .WriteString (fmt .Sprintf (`const %s = %s
260+ ` , shapeName , c .getStructShape (t , 0 )))
261+
254262 output .WriteString (fmt .Sprintf (
255- `export const %s: z.ZodType<%s> = %s ` , schemaName (c .prefix , name ), fullName , data ))
263+ `export const %s: z.ZodType<%s> = z.object(%s) ` , schemaName (c .prefix , name ), fullName , shapeName ))
256264 } else {
257265 output .WriteString (fmt .Sprintf (
258266 `export const %s = %s
@@ -265,6 +273,33 @@ func (c *Converter) convertStructTopLevel(t reflect.Type) string {
265273
266274 c .stack = c .stack [:len (c .stack )- 1 ]
267275
276+ return output .String (), top .selfRef
277+ }
278+
279+ func (c * Converter ) getStructShape (input reflect.Type , indent int ) string {
280+ output := strings.Builder {}
281+
282+ output .WriteString (`{
283+ ` )
284+
285+ fields := input .NumField ()
286+ for i := 0 ; i < fields ; i ++ {
287+ field := input .Field (i )
288+ optional := isOptional (field )
289+ nullable := isNullable (field )
290+
291+ line , shouldMerge := c .convertField (field , indent + 1 , optional , nullable )
292+
293+ if ! shouldMerge {
294+ output .WriteString (line )
295+ } else {
296+ output .WriteString (fmt .Sprintf ("%s...%s.shape,\n " , indentation (indent + 1 ), schemaName (c .prefix , typeName (field .Type ))))
297+ }
298+ }
299+
300+ output .WriteString (indentation (indent ))
301+ output .WriteString (`}` )
302+
268303 return output .String ()
269304}
270305
@@ -282,7 +317,7 @@ func (c *Converter) convertStruct(input reflect.Type, indent int) string {
282317 optional := isOptional (field )
283318 nullable := isNullable (field )
284319
285- line , shouldMerge := c .convertField (field , indent + 1 , optional , nullable , field . Anonymous )
320+ line , shouldMerge := c .convertField (field , indent + 1 , optional , nullable )
286321
287322 if ! shouldMerge {
288323 output .WriteString (line )
@@ -308,21 +343,36 @@ func (c *Converter) getTypeStruct(input reflect.Type, indent int) string {
308343 output .WriteString (`{
309344` )
310345
346+ merges := []string {}
347+
311348 fields := input .NumField ()
312349 for i := 0 ; i < fields ; i ++ {
313350 field := input .Field (i )
314351 optional := isOptional (field )
315352 nullable := isNullable (field )
316353
317- line := c .getTypeField (field , indent + 1 , optional , nullable )
354+ line , shouldMerge := c .getTypeField (field , indent + 1 , optional , nullable )
318355
319- output .WriteString (line )
356+ if ! shouldMerge {
357+ output .WriteString (line )
358+ } else {
359+ merges = append (merges , line )
360+ }
320361 }
321362
322363 output .WriteString (indentation (indent ))
323364 output .WriteString (`}` )
324365
325- return output .String ()
366+ if len (merges ) == 0 {
367+ return output .String ()
368+ }
369+
370+ newOutput := strings.Builder {}
371+ for _ , merge := range merges {
372+ newOutput .WriteString (fmt .Sprintf ("%s & " , merge ))
373+ }
374+ newOutput .WriteString (output .String ())
375+ return newOutput .String ()
326376}
327377
328378var matchGenericTypeName = regexp .MustCompile (`(.+)\[(.+)]` )
@@ -400,7 +450,8 @@ func (c *Converter) ConvertType(t reflect.Type, validate string, indent int) str
400450 } else {
401451 // throws panic if there is a cycle
402452 detectCycle (name , c .stack )
403- c .addSchema (name , c .convertStructTopLevel (t ))
453+ data , selfRef := c .convertStructTopLevel (t )
454+ c .addSchema (name , data , selfRef )
404455 validateStr .WriteString (schemaName (c .prefix , name ))
405456 }
406457 }
@@ -487,7 +538,7 @@ func (c *Converter) getType(t reflect.Type, indent int) string {
487538 return zodType
488539}
489540
490- func (c * Converter ) convertField (f reflect.StructField , indent int , optional , nullable , anonymous bool ) (string , bool ) {
541+ func (c * Converter ) convertField (f reflect.StructField , indent int , optional , nullable bool ) (string , bool ) {
491542 name := fieldName (f )
492543
493544 // fields named `-` are not exported to JSON so don't export zod types
@@ -510,7 +561,7 @@ func (c *Converter) convertField(f reflect.StructField, indent int, optional, nu
510561 }
511562
512563 t := c .ConvertType (f .Type , f .Tag .Get ("validate" ), indent )
513- if ! anonymous {
564+ if ! f . Anonymous {
514565 return fmt .Sprintf (
515566 "%s%s: %s%s%s,\n " ,
516567 indentation (indent ),
@@ -519,16 +570,23 @@ func (c *Converter) convertField(f reflect.StructField, indent int, optional, nu
519570 optionalCall ,
520571 nullableCall ), false
521572 } else {
573+ typeName := typeName (f .Type )
574+ entry , ok := c .outputs [typeName ]
575+ if ok && entry .selfRef {
576+ // Since we are spreading shape, we won't be able to support any validation tags on the embedded field
577+ return fmt .Sprintf ("%s...%s,\n " , indentation (indent ), shapeName (c .prefix , typeName )), false
578+ }
579+
522580 return fmt .Sprintf (".merge(%s)" , t ), true
523581 }
524582}
525583
526- func (c * Converter ) getTypeField (f reflect.StructField , indent int , optional , nullable bool ) string {
584+ func (c * Converter ) getTypeField (f reflect.StructField , indent int , optional , nullable bool ) ( string , bool ) {
527585 name := fieldName (f )
528586
529587 // fields named `-` are not exported to JSON so don't export types
530588 if name == "-" {
531- return ""
589+ return "" , false
532590 }
533591
534592 // because nullability is processed before custom types, this makes sure
@@ -547,14 +605,18 @@ func (c *Converter) getTypeField(f reflect.StructField, indent int, optional, nu
547605 nullableCall = " | null"
548606 }
549607
550- return fmt .Sprintf (
551- "%s%s%s: %s%s%s,\n " ,
552- indentation (indent ),
553- name ,
554- optionalCallPre ,
555- c .getType (f .Type , indent ),
556- nullableCall ,
557- optionalCallUndef )
608+ if ! f .Anonymous {
609+ return fmt .Sprintf (
610+ "%s%s%s: %s%s%s,\n " ,
611+ indentation (indent ),
612+ name ,
613+ optionalCallPre ,
614+ c .getType (f .Type , indent ),
615+ nullableCall ,
616+ optionalCallUndef ), false
617+ }
618+
619+ return typeName (f .Type ), true
558620}
559621
560622func (c * Converter ) convertSliceAndArray (t reflect.Type , validate string , indent int ) string {
@@ -885,27 +947,27 @@ func (c *Converter) validateString(validate string) string {
885947 // const FishEnum = z.enum(["Salmon", "Tuna", "Trout"]);
886948 validateStr .WriteString (fmt .Sprintf (".enum([\" %s\" ] as const)" , strings .Join (vals , "\" , \" " )))
887949 case "len" :
888- validateStr . WriteString ( fmt .Sprintf (".length(%s)" , valValue ))
950+ refines = append ( refines , fmt .Sprintf (".refine((val) => [...val].length === %s, 'String must contain %s character(s)')" , valValue , valValue ))
889951 case "min" :
890- validateStr . WriteString ( fmt .Sprintf (".min(%s)" , valValue ))
952+ refines = append ( refines , fmt .Sprintf (".refine((val) => [...val].length >= %s, 'String must contain at least %s character(s)')" , valValue , valValue ))
891953 case "max" :
892- validateStr . WriteString ( fmt .Sprintf (".max(%s)" , valValue ))
954+ refines = append ( refines , fmt .Sprintf (".refine((val) => [...val].length <= %s, 'String must contain at most %s character(s)')" , valValue , valValue ))
893955 case "gt" :
894956 val , err := strconv .Atoi (valValue )
895957 if err != nil {
896958 panic ("gt= must be followed by a number" )
897959 }
898- validateStr . WriteString ( fmt .Sprintf (".min(%d)" , val + 1 ))
960+ refines = append ( refines , fmt .Sprintf (".refine((val) => [...val].length > %d, 'String must contain at least %d character(s)')" , val , val + 1 ))
899961 case "gte" :
900- validateStr . WriteString ( fmt .Sprintf (".min(%s)" , valValue ))
962+ refines = append ( refines , fmt .Sprintf (".refine((val) => [...val].length >= %s, 'String must contain at least %s character(s)')" , valValue , valValue ))
901963 case "lt" :
902964 val , err := strconv .Atoi (valValue )
903965 if err != nil {
904966 panic ("lt= must be followed by a number" )
905967 }
906- validateStr . WriteString ( fmt .Sprintf (".max(%d)" , val - 1 ))
968+ refines = append ( refines , fmt .Sprintf (".refine((val) => [...val].length < %d, 'String must contain at most %d character(s)')" , val , val - 1 ))
907969 case "lte" :
908- validateStr . WriteString ( fmt .Sprintf (".max(%s)" , valValue ))
970+ refines = append ( refines , fmt .Sprintf (".refine((val) => [...val].length <= %s, 'String must contain at most %s character(s)')" , valValue , valValue ))
909971 case "contains" :
910972 validateStr .WriteString (fmt .Sprintf (".includes(\" %s\" )" , valValue ))
911973 case "endswith" :
0 commit comments