@@ -19,7 +19,10 @@ func TestNestedSubstInGenericFunction(t *testing.T) {
1919
2020 func D(){
2121 type E[V any] struct{X V}
22- }`
22+ }
23+
24+ type F[W any] struct{X W}
25+ `
2326
2427 fSet := token .NewFileSet ()
2528 f , err := parser .ParseFile (fSet , "hello.go" , source , 0 )
@@ -33,51 +36,105 @@ func TestNestedSubstInGenericFunction(t *testing.T) {
3336 t .Fatal (err )
3437 }
3538
39+ type namedType struct {
40+ name string // the name of the named type
41+ args []string // type expressions of args for the named type
42+ }
43+
3644 for _ , test := range []struct {
37- fnName string // the name of the nesting function
38- fnArgs []string // type expressions of args for the nesting function
39- stName string // the name of the named type
40- stArgs []string // type expressions of args for the named type
41- want string // expected underlying value after substitution
45+ nesting []namedType
46+ want string // expected underlying value after substitution
4247 }{
4348 {
44- fnName : `A` , fnArgs : []string {`int` },
45- stName : `B` , stArgs : []string {},
49+ nesting : []namedType {
50+ {name : `A` , args : []string {`int` }},
51+ },
4652 want : `struct{X int}` ,
4753 },
4854 {
49- fnName : `A` , fnArgs : []string {`int` },
50- stName : `C` , stArgs : []string {`bool` },
55+ nesting : []namedType {
56+ {name : `A` , args : []string {`int` }},
57+ {name : `B` },
58+ },
59+ want : `struct{X int}` ,
60+ },
61+ {
62+ nesting : []namedType {
63+ {name : `A` , args : []string {`int` }},
64+ {name : `C` , args : []string {`bool` }},
65+ },
5166 want : "struct{X int; Y bool}" ,
5267 },
5368 {
54- fnName : `D` , fnArgs : []string {},
55- stName : `E` , stArgs : []string {`int` },
69+ nesting : []namedType {
70+ {name : `D` },
71+ {name : `E` , args : []string {`int` }},
72+ },
5673 want : "struct{X int}" ,
5774 },
75+ {
76+ nesting : []namedType {
77+ {name : `F` , args : []string {`int` }},
78+ },
79+ want : `struct{X int}` ,
80+ },
5881 } {
59- ctxt := types .NewContext ()
60-
61- fnGen , _ := pkg .Scope ().Lookup (test .fnName ).(* types.Func )
62- if fnGen == nil {
63- t .Fatal ("Failed to find the function " + test .fnName )
82+ if len (test .nesting ) == 0 {
83+ t .Fatal (`Must have at least one names type to instantiate` )
6484 }
65- fnArgs := evalTypeList (t , fSet , pkg , test .fnArgs )
66- fnFunc := types .NewFunc (fnGen .Pos (), pkg , fnGen .Name (), fnGen .Type ().(* types.Signature ))
6785
68- stType , _ := fnFunc .Scope ().Lookup (test .stName ).Type ().(* types.Named )
69- if stType == nil {
70- t .Fatal ("Failed to find the object " + test .fnName + " in function " + test .fnName )
86+ ctxt := types .NewContext ()
87+ var subst * Subster
88+ var obj types.Object
89+ scope := pkg .Scope ()
90+ for _ , nt := range test .nesting {
91+ obj = scope .Lookup (nt .name )
92+ if obj == nil {
93+ t .Fatalf (`Failed to find %s in package scope` , nt .name )
94+ }
95+ if fn , ok := obj .(* types.Func ); ok {
96+ scope = fn .Scope ()
97+ }
98+ args := evalTypeList (t , fSet , pkg , nt .args )
99+ tp := getTypeParams (t , obj .Type ())
100+ subst = New (ctxt , tp , args , subst )
71101 }
72- stArgs := evalTypeList (t , fSet , pkg , test .stArgs )
73102
74- stSubst := NewNested (ctxt , fnFunc , fnArgs , stType .TypeParams (), stArgs )
75- stInst := stSubst .Type (stType .Underlying ())
103+ shouldNotPanic (t , func () {
104+ stInst := subst .Type (obj .Type ().Underlying ())
105+ if got := stInst .String (); got != test .want {
106+ t .Errorf ("%s.typ(%s) = %v, want %v" , subst , obj .Type ().Underlying (), got , test .want )
107+ }
108+ })
109+ }
110+ }
111+
112+ func shouldNotPanic (t * testing.T , f func ()) {
113+ t .Helper ()
114+ defer func () {
115+ if r := recover (); r != nil {
116+ t .Errorf (`panicked: %v` , r )
117+ }
118+ }()
119+ f ()
120+ }
76121
77- if got := stInst .String (); got != test .want {
78- t .Errorf ("subst{%v->%v}.typ(%s) = %v, want %v" , test .stName , test .stArgs , stType .Underlying (), got , test .want )
122+ func getTypeParams (t * testing.T , typ types.Type ) * types.TypeParamList {
123+ switch typ := typ .(type ) {
124+ case * types.Named :
125+ return typ .TypeParams ()
126+ case * types.Signature :
127+ if tp := typ .RecvTypeParams (); tp != nil && tp .Len () > 0 {
128+ return tp
79129 }
130+ return typ .TypeParams ()
131+ case interface { Elem () types.Type }:
132+ // Pointer, slice, array, map, and channel types.
133+ return getTypeParams (t , typ .Elem ())
134+ default :
135+ t .Fatalf (`getTypeParams(%v) hit unexpected type` , typ )
80136 }
137+ return nil
81138}
82139
83140func evalType (t * testing.T , fSet * token.FileSet , pkg * types.Package , expr string ) types.Type {
0 commit comments