Skip to content

Commit a9f17ae

Browse files
committed
Add Type.IsComparable()
1 parent 0359904 commit a9f17ae

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

v2/parser/parse.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type
652652
switch t := in.(type) {
653653
case *gotypes.Struct:
654654
out := u.Type(name)
655+
out.GoType = in
655656
if out.Kind != types.Unknown {
656657
return out
657658
}
@@ -670,6 +671,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type
670671
return out
671672
case *gotypes.Map:
672673
out := u.Type(name)
674+
out.GoType = in
673675
if out.Kind != types.Unknown {
674676
return out
675677
}
@@ -679,6 +681,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type
679681
return out
680682
case *gotypes.Pointer:
681683
out := u.Type(name)
684+
out.GoType = in
682685
if out.Kind != types.Unknown {
683686
return out
684687
}
@@ -687,6 +690,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type
687690
return out
688691
case *gotypes.Slice:
689692
out := u.Type(name)
693+
out.GoType = in
690694
if out.Kind != types.Unknown {
691695
return out
692696
}
@@ -695,6 +699,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type
695699
return out
696700
case *gotypes.Array:
697701
out := u.Type(name)
702+
out.GoType = in
698703
if out.Kind != types.Unknown {
699704
return out
700705
}
@@ -704,6 +709,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type
704709
return out
705710
case *gotypes.Chan:
706711
out := u.Type(name)
712+
out.GoType = in
707713
if out.Kind != types.Unknown {
708714
return out
709715
}
@@ -717,13 +723,15 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type
717723
Package: "", // This is a magic package name in the Universe.
718724
Name: t.Name(),
719725
})
726+
out.GoType = in
720727
if out.Kind != types.Unknown {
721728
return out
722729
}
723730
out.Kind = types.Unsupported
724731
return out
725732
case *gotypes.Signature:
726733
out := u.Type(name)
734+
out.GoType = in
727735
if out.Kind != types.Unknown {
728736
return out
729737
}
@@ -732,6 +740,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type
732740
return out
733741
case *gotypes.Interface:
734742
out := u.Type(name)
743+
out.GoType = in
735744
if out.Kind != types.Unknown {
736745
return out
737746
}
@@ -754,6 +763,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type
754763
case *gotypes.Named, *gotypes.Basic, *gotypes.Map, *gotypes.Slice:
755764
name := goNameToName(t.String())
756765
out = u.Type(name)
766+
out.GoType = in
757767
if out.Kind != types.Unknown {
758768
return out
759769
}
@@ -776,6 +786,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type
776786
}
777787

778788
if out := u.Type(name); out.Kind != types.Unknown {
789+
out.GoType = in
779790
return out // short circuit if we've already made this.
780791
}
781792
out = p.walkType(u, &name, t.Underlying())
@@ -817,6 +828,7 @@ func (p *Parser) walkType(u types.Universe, useName *types.Name, in gotypes.Type
817828
}
818829
default:
819830
out := u.Type(name)
831+
out.GoType = in
820832
if out.Kind != types.Unknown {
821833
return out
822834
}

v2/parser/parse_test.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"testing"
2626

2727
"github.com/google/go-cmp/cmp"
28+
"github.com/google/go-cmp/cmp/cmpopts"
2829
"golang.org/x/tools/go/packages"
2930
"k8s.io/gengo/v2/types"
3031
)
@@ -1118,8 +1119,14 @@ func TestStructParse(t *testing.T) {
11181119
if st == nil || st.Kind == types.Unknown {
11191120
t.Fatalf("type %s not found", expected.Name.Name)
11201121
}
1121-
if e, a := expected, st; !reflect.DeepEqual(e, a) {
1122-
t.Errorf("wanted, got:\n%#v\n%#v\n%s", e, a, cmp.Diff(e, a))
1122+
if st.GoType == nil {
1123+
t.Errorf("type %s did not have GoType", expected.Name.Name)
1124+
}
1125+
opts := []cmp.Option{
1126+
cmpopts.IgnoreTypes(types.Type{}, "GoType"),
1127+
}
1128+
if e, a := expected, st; !cmp.Equal(e, a, opts...) {
1129+
t.Errorf("wanted, got:\n%#v\n%#v\n%s", e, a, cmp.Diff(e, a, opts...))
11231130
}
11241131
})
11251132
}

v2/types/types.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ limitations under the License.
1616

1717
package types
1818

19-
import "strings"
19+
import (
20+
gotypes "go/types"
21+
"strings"
22+
)
2023

2124
// Ref makes a reference to the given type. It can only be used for e.g.
2225
// passing to namers.
@@ -358,6 +361,9 @@ type Type struct {
358361

359362
// If Kind == Array
360363
Len int64
364+
365+
// The underlying Go type.
366+
GoType gotypes.Type
361367
}
362368

363369
// String returns the name of the type.
@@ -402,6 +408,11 @@ func (t *Type) IsAnonymousStruct() bool {
402408
return (t.Kind == Struct && t.Name.Name == "struct{}") || (t.Kind == Alias && t.Underlying.IsAnonymousStruct())
403409
}
404410

411+
// IsComparable returns whether the type is comparable.
412+
func (t *Type) IsComparable() bool {
413+
return gotypes.Comparable(t.GoType)
414+
}
415+
405416
// A single struct member
406417
type Member struct {
407418
// The name of the member.

0 commit comments

Comments
 (0)