Skip to content

Commit 67d8f78

Browse files
authored
Fix schema generation for embedded recursive structs (#17)
* Fix schema generation for embedded recursive structs * Fix issue in getStructShape and add corresponding UT * Fix string length validation inconsistency between zod and validator (#19)
1 parent f521ea1 commit 67d8f78

File tree

2 files changed

+195
-53
lines changed

2 files changed

+195
-53
lines changed

zod.go

Lines changed: 96 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,8 @@ func (c *Converter) AddType(input interface{}) {
8888
return
8989
}
9090

91-
data := c.convertStructTopLevel(t)
92-
order := c.structs
93-
c.outputs[name] = entry{order, data}
94-
c.structs = order + 1
91+
data, selfRef := c.convertStructTopLevel(t)
92+
c.addSchema(name, data, selfRef)
9593
}
9694

9795
// Convert returns zod schema corresponding to a struct type. Its a shorthand for
@@ -143,8 +141,9 @@ var typeMapping = map[reflect.Kind]string{
143141
}
144142

145143
type entry struct {
146-
order int
147-
data string
144+
order int
145+
data string
146+
selfRef bool
148147
}
149148

150149
type byOrder []entry
@@ -170,12 +169,12 @@ type Converter struct {
170169
stack []meta
171170
}
172171

173-
func (c *Converter) addSchema(name string, data string) {
172+
func (c *Converter) addSchema(name string, data string, selfRef bool) {
174173
// First check if the object already exists. If it does do not replace. This is needed for second order
175174
_, ok := c.outputs[name]
176175
if !ok {
177176
order := c.structs
178-
c.outputs[name] = entry{order, data}
177+
c.outputs[name] = entry{order, data, selfRef}
179178
c.structs = order + 1
180179
}
181180
}
@@ -203,6 +202,10 @@ func schemaName(prefix, name string) string {
203202
return fmt.Sprintf("%s%sSchema", prefix, name)
204203
}
205204

205+
func shapeName(prefix, name string) string {
206+
return schemaName(prefix, name) + "Shape"
207+
}
208+
206209
func fieldName(input reflect.StructField) string {
207210
if json := input.Tag.Get("json"); json != "" {
208211
args := strings.Split(json, ",")
@@ -237,7 +240,7 @@ func typeName(t reflect.Type) string {
237240
return "UNKNOWN"
238241
}
239242

240-
func (c *Converter) convertStructTopLevel(t reflect.Type) string {
243+
func (c *Converter) convertStructTopLevel(t reflect.Type) (string, bool) {
241244
output := strings.Builder{}
242245

243246
name := typeName(t)
@@ -248,11 +251,16 @@ func (c *Converter) convertStructTopLevel(t reflect.Type) string {
248251

249252
top := c.stack[len(c.stack)-1]
250253
if top.selfRef {
254+
shapeName := shapeName(c.prefix, name)
255+
251256
output.WriteString(fmt.Sprintf(`export type %s = %s
252257
`, fullName, c.getTypeStruct(t, 0)))
253258

259+
output.WriteString(fmt.Sprintf(`const %s = %s
260+
`, shapeName, c.getStructShape(t, 0)))
261+
254262
output.WriteString(fmt.Sprintf(
255-
`export const %s: z.ZodType<%s> = %s`, schemaName(c.prefix, name), fullName, data))
263+
`export const %s: z.ZodType<%s> = z.object(%s)`, schemaName(c.prefix, name), fullName, shapeName))
256264
} else {
257265
output.WriteString(fmt.Sprintf(
258266
`export const %s = %s
@@ -265,6 +273,33 @@ func (c *Converter) convertStructTopLevel(t reflect.Type) string {
265273

266274
c.stack = c.stack[:len(c.stack)-1]
267275

276+
return output.String(), top.selfRef
277+
}
278+
279+
func (c *Converter) getStructShape(input reflect.Type, indent int) string {
280+
output := strings.Builder{}
281+
282+
output.WriteString(`{
283+
`)
284+
285+
fields := input.NumField()
286+
for i := 0; i < fields; i++ {
287+
field := input.Field(i)
288+
optional := isOptional(field)
289+
nullable := isNullable(field)
290+
291+
line, shouldMerge := c.convertField(field, indent+1, optional, nullable)
292+
293+
if !shouldMerge {
294+
output.WriteString(line)
295+
} else {
296+
output.WriteString(fmt.Sprintf("%s...%s.shape,\n", indentation(indent+1), schemaName(c.prefix, typeName(field.Type))))
297+
}
298+
}
299+
300+
output.WriteString(indentation(indent))
301+
output.WriteString(`}`)
302+
268303
return output.String()
269304
}
270305

@@ -282,7 +317,7 @@ func (c *Converter) convertStruct(input reflect.Type, indent int) string {
282317
optional := isOptional(field)
283318
nullable := isNullable(field)
284319

285-
line, shouldMerge := c.convertField(field, indent+1, optional, nullable, field.Anonymous)
320+
line, shouldMerge := c.convertField(field, indent+1, optional, nullable)
286321

287322
if !shouldMerge {
288323
output.WriteString(line)
@@ -308,21 +343,36 @@ func (c *Converter) getTypeStruct(input reflect.Type, indent int) string {
308343
output.WriteString(`{
309344
`)
310345

346+
merges := []string{}
347+
311348
fields := input.NumField()
312349
for i := 0; i < fields; i++ {
313350
field := input.Field(i)
314351
optional := isOptional(field)
315352
nullable := isNullable(field)
316353

317-
line := c.getTypeField(field, indent+1, optional, nullable)
354+
line, shouldMerge := c.getTypeField(field, indent+1, optional, nullable)
318355

319-
output.WriteString(line)
356+
if !shouldMerge {
357+
output.WriteString(line)
358+
} else {
359+
merges = append(merges, line)
360+
}
320361
}
321362

322363
output.WriteString(indentation(indent))
323364
output.WriteString(`}`)
324365

325-
return output.String()
366+
if len(merges) == 0 {
367+
return output.String()
368+
}
369+
370+
newOutput := strings.Builder{}
371+
for _, merge := range merges {
372+
newOutput.WriteString(fmt.Sprintf("%s & ", merge))
373+
}
374+
newOutput.WriteString(output.String())
375+
return newOutput.String()
326376
}
327377

328378
var matchGenericTypeName = regexp.MustCompile(`(.+)\[(.+)]`)
@@ -400,7 +450,8 @@ func (c *Converter) ConvertType(t reflect.Type, validate string, indent int) str
400450
} else {
401451
// throws panic if there is a cycle
402452
detectCycle(name, c.stack)
403-
c.addSchema(name, c.convertStructTopLevel(t))
453+
data, selfRef := c.convertStructTopLevel(t)
454+
c.addSchema(name, data, selfRef)
404455
validateStr.WriteString(schemaName(c.prefix, name))
405456
}
406457
}
@@ -487,7 +538,7 @@ func (c *Converter) getType(t reflect.Type, indent int) string {
487538
return zodType
488539
}
489540

490-
func (c *Converter) convertField(f reflect.StructField, indent int, optional, nullable, anonymous bool) (string, bool) {
541+
func (c *Converter) convertField(f reflect.StructField, indent int, optional, nullable bool) (string, bool) {
491542
name := fieldName(f)
492543

493544
// fields named `-` are not exported to JSON so don't export zod types
@@ -510,7 +561,7 @@ func (c *Converter) convertField(f reflect.StructField, indent int, optional, nu
510561
}
511562

512563
t := c.ConvertType(f.Type, f.Tag.Get("validate"), indent)
513-
if !anonymous {
564+
if !f.Anonymous {
514565
return fmt.Sprintf(
515566
"%s%s: %s%s%s,\n",
516567
indentation(indent),
@@ -519,16 +570,23 @@ func (c *Converter) convertField(f reflect.StructField, indent int, optional, nu
519570
optionalCall,
520571
nullableCall), false
521572
} else {
573+
typeName := typeName(f.Type)
574+
entry, ok := c.outputs[typeName]
575+
if ok && entry.selfRef {
576+
// Since we are spreading shape, we won't be able to support any validation tags on the embedded field
577+
return fmt.Sprintf("%s...%s,\n", indentation(indent), shapeName(c.prefix, typeName)), false
578+
}
579+
522580
return fmt.Sprintf(".merge(%s)", t), true
523581
}
524582
}
525583

526-
func (c *Converter) getTypeField(f reflect.StructField, indent int, optional, nullable bool) string {
584+
func (c *Converter) getTypeField(f reflect.StructField, indent int, optional, nullable bool) (string, bool) {
527585
name := fieldName(f)
528586

529587
// fields named `-` are not exported to JSON so don't export types
530588
if name == "-" {
531-
return ""
589+
return "", false
532590
}
533591

534592
// because nullability is processed before custom types, this makes sure
@@ -547,14 +605,18 @@ func (c *Converter) getTypeField(f reflect.StructField, indent int, optional, nu
547605
nullableCall = " | null"
548606
}
549607

550-
return fmt.Sprintf(
551-
"%s%s%s: %s%s%s,\n",
552-
indentation(indent),
553-
name,
554-
optionalCallPre,
555-
c.getType(f.Type, indent),
556-
nullableCall,
557-
optionalCallUndef)
608+
if !f.Anonymous {
609+
return fmt.Sprintf(
610+
"%s%s%s: %s%s%s,\n",
611+
indentation(indent),
612+
name,
613+
optionalCallPre,
614+
c.getType(f.Type, indent),
615+
nullableCall,
616+
optionalCallUndef), false
617+
}
618+
619+
return typeName(f.Type), true
558620
}
559621

560622
func (c *Converter) convertSliceAndArray(t reflect.Type, validate string, indent int) string {
@@ -885,27 +947,27 @@ func (c *Converter) validateString(validate string) string {
885947
// const FishEnum = z.enum(["Salmon", "Tuna", "Trout"]);
886948
validateStr.WriteString(fmt.Sprintf(".enum([\"%s\"] as const)", strings.Join(vals, "\", \"")))
887949
case "len":
888-
validateStr.WriteString(fmt.Sprintf(".length(%s)", valValue))
950+
refines = append(refines, fmt.Sprintf(".refine((val) => [...val].length === %s, 'String must contain %s character(s)')", valValue, valValue))
889951
case "min":
890-
validateStr.WriteString(fmt.Sprintf(".min(%s)", valValue))
952+
refines = append(refines, fmt.Sprintf(".refine((val) => [...val].length >= %s, 'String must contain at least %s character(s)')", valValue, valValue))
891953
case "max":
892-
validateStr.WriteString(fmt.Sprintf(".max(%s)", valValue))
954+
refines = append(refines, fmt.Sprintf(".refine((val) => [...val].length <= %s, 'String must contain at most %s character(s)')", valValue, valValue))
893955
case "gt":
894956
val, err := strconv.Atoi(valValue)
895957
if err != nil {
896958
panic("gt= must be followed by a number")
897959
}
898-
validateStr.WriteString(fmt.Sprintf(".min(%d)", val+1))
960+
refines = append(refines, fmt.Sprintf(".refine((val) => [...val].length > %d, 'String must contain at least %d character(s)')", val, val+1))
899961
case "gte":
900-
validateStr.WriteString(fmt.Sprintf(".min(%s)", valValue))
962+
refines = append(refines, fmt.Sprintf(".refine((val) => [...val].length >= %s, 'String must contain at least %s character(s)')", valValue, valValue))
901963
case "lt":
902964
val, err := strconv.Atoi(valValue)
903965
if err != nil {
904966
panic("lt= must be followed by a number")
905967
}
906-
validateStr.WriteString(fmt.Sprintf(".max(%d)", val-1))
968+
refines = append(refines, fmt.Sprintf(".refine((val) => [...val].length < %d, 'String must contain at most %d character(s)')", val, val-1))
907969
case "lte":
908-
validateStr.WriteString(fmt.Sprintf(".max(%s)", valValue))
970+
refines = append(refines, fmt.Sprintf(".refine((val) => [...val].length <= %s, 'String must contain at most %s character(s)')", valValue, valValue))
909971
case "contains":
910972
validateStr.WriteString(fmt.Sprintf(".includes(\"%s\")", valValue))
911973
case "endswith":

0 commit comments

Comments
 (0)