Skip to content

Commit 59d2ced

Browse files
committed
pr feedback: unique_slice_with_nil_and_zero_value_struct
1 parent 77361f2 commit 59d2ced

File tree

2 files changed

+85
-53
lines changed

2 files changed

+85
-53
lines changed

baked_in.go

Lines changed: 77 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -336,87 +336,111 @@ func isOneOfCI(fl FieldLevel) bool {
336336
func isUnique(fl FieldLevel) bool {
337337
field := fl.Field()
338338
param := fl.Param()
339-
v := reflect.ValueOf(struct{}{})
339+
340+
// sentinel used as map key for nil values
341+
var nilKey = struct{}{}
340342

341343
switch field.Kind() {
344+
342345
case reflect.Slice, reflect.Array:
343-
elem := field.Type().Elem()
344-
if elem.Kind() == reflect.Ptr {
345-
elem = elem.Elem()
346-
}
346+
seen := make(map[interface{}]struct{})
347347

348-
if param == "" {
349-
m := reflect.MakeMap(reflect.MapOf(elem, v.Type()))
350-
zero := reflect.Zero(elem)
348+
for i := 0; i < field.Len(); i++ {
349+
elem := field.Index(i)
350+
351+
// -------- unique (no param) --------
352+
if param == "" {
353+
var key interface{}
354+
355+
if elem.Kind() == reflect.Ptr {
356+
if elem.IsNil() {
357+
key = nilKey
358+
} else {
359+
key = elem.Elem().Interface() // <-- compare underlying value
360+
}
361+
} else {
362+
key = elem.Interface()
363+
}
351364

352-
for i := 0; i < field.Len(); i++ {
353-
e := reflect.Indirect(field.Index(i))
354-
if !e.IsValid() {
355-
m.SetMapIndex(zero, v)
356-
continue
365+
if _, ok := seen[key]; ok {
366+
return false
357367
}
358-
m.SetMapIndex(e, v)
368+
seen[key] = struct{}{}
369+
continue
359370
}
360-
return field.Len() == m.Len()
361-
}
362371

363-
sf, ok := elem.FieldByName(param)
364-
if !ok {
365-
panic(fmt.Sprintf("Bad field name %s", param))
366-
}
372+
// -------- unique=Field --------
367373

368-
sfTyp := sf.Type
369-
if sfTyp.Kind() == reflect.Ptr {
370-
sfTyp = sfTyp.Elem()
371-
}
374+
if elem.Kind() == reflect.Ptr {
375+
if elem.IsNil() {
376+
if _, ok := seen[nilKey]; ok {
377+
return false
378+
}
379+
seen[nilKey] = struct{}{}
380+
continue
381+
}
382+
elem = elem.Elem()
383+
}
372384

373-
m := reflect.MakeMap(reflect.MapOf(sfTyp, v.Type()))
374-
zero := reflect.Zero(sfTyp)
385+
if elem.Kind() != reflect.Struct {
386+
panic(fmt.Sprintf("Bad field type %s", elem.Type()))
387+
}
375388

376-
for i := 0; i < field.Len(); i++ {
377-
parent := reflect.Indirect(field.Index(i))
378-
if !parent.IsValid() {
379-
m.SetMapIndex(zero, v)
380-
continue
389+
sf := elem.FieldByName(param)
390+
if !sf.IsValid() {
391+
panic(fmt.Sprintf("Bad field name %s", param))
381392
}
382393

383-
key := reflect.Indirect(parent.FieldByName(param))
384-
if !key.IsValid() {
385-
m.SetMapIndex(zero, v)
386-
continue
394+
var key interface{}
395+
396+
if sf.Kind() == reflect.Ptr {
397+
if sf.IsNil() {
398+
key = nilKey
399+
} else {
400+
key = sf.Elem().Interface()
401+
}
402+
} else {
403+
key = sf.Interface()
387404
}
388405

389-
m.SetMapIndex(key, v)
406+
if _, ok := seen[key]; ok {
407+
return false
408+
}
409+
seen[key] = struct{}{}
390410
}
391411

392-
return field.Len() == m.Len()
412+
return true
393413

394414
case reflect.Map:
395-
var keyType reflect.Type
396-
if field.Type().Elem().Kind() == reflect.Ptr {
397-
keyType = field.Type().Elem().Elem()
398-
} else {
399-
keyType = field.Type().Elem()
400-
}
401-
402-
m := reflect.MakeMap(reflect.MapOf(keyType, v.Type()))
403-
zero := reflect.Zero(keyType)
415+
seen := make(map[interface{}]struct{})
404416

405417
for _, k := range field.MapKeys() {
406-
val := reflect.Indirect(field.MapIndex(k))
407-
if !val.IsValid() {
408-
m.SetMapIndex(zero, v)
409-
continue
418+
val := field.MapIndex(k)
419+
420+
var key interface{}
421+
422+
if val.Kind() == reflect.Ptr {
423+
if val.IsNil() {
424+
key = nilKey
425+
} else {
426+
key = val.Elem().Interface() // <-- compare underlying value
427+
}
428+
} else {
429+
key = val.Interface()
430+
}
431+
432+
if _, ok := seen[key]; ok {
433+
return false
410434
}
411-
m.SetMapIndex(val, v)
435+
seen[key] = struct{}{}
412436
}
413437

414-
return field.Len() == m.Len()
438+
return true
415439

416440
default:
417441
if parent := fl.Parent(); parent.Kind() == reflect.Struct {
418442
uniqueField := parent.FieldByName(param)
419-
if uniqueField == reflect.ValueOf(nil) {
443+
if !uniqueField.IsValid() {
420444
panic(fmt.Sprintf("Bad field name provided %s", param))
421445
}
422446

validator_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11029,6 +11029,14 @@ func TestUniqueValidationNilPtrSlice(t *testing.T) {
1102911029
t.Fatalf("nil and non-nil map values should pass unique validation, got: %v", errs)
1103011030
}
1103111031
})
11032+
11033+
t.Run("unique_slice_with_nil_and_zero_value_struct", func(t *testing.T) {
11034+
s := []*Inner{nil, {Name: ""}}
11035+
errs := validate.Var(s, "unique")
11036+
if errs != nil {
11037+
t.Fatalf("nil and non-nil map values should pass unique validation, got: %v", errs)
11038+
}
11039+
})
1103211040
}
1103311041

1103411042
func TestHTMLValidation(t *testing.T) {

0 commit comments

Comments
 (0)