Skip to content

Commit e82fa78

Browse files
working on identifying where nesting types should get injected
1 parent 72498bf commit e82fa78

File tree

8 files changed

+223
-37
lines changed

8 files changed

+223
-37
lines changed

compiler/compiler_test.go

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -750,8 +750,53 @@ func TestArchiveSelectionAfterSerialization(t *testing.T) {
750750
}
751751
}
752752

753-
func TestNestedTypesDuplicateDecls(t *testing.T) {
754-
// This is a subset of the type param nested test.
753+
func TestNestedConcreteTypeInGenericFunc(t *testing.T) {
754+
// This is a test of a type defined inside a generic function
755+
// that uses the type parameter of the function as a field type.
756+
// The `T` type is unique for each instance of `F`.
757+
// The use of `A` as a field is do demonstrate the difference in the types
758+
// however even if T had no fields, the type would still be different.
759+
//
760+
// Change `print(F[?]())` to `fmt.Printf("%T\n", F[?]())` for
761+
// golang playground to print the type of T in the different F instances.
762+
// (I just didn't want this test to depend on `fmt` when it doesn't need to.)
763+
764+
src := `
765+
package main
766+
767+
func F[A any]() any {
768+
type T struct{
769+
a A
770+
}
771+
return T{}
772+
}
773+
774+
func main() {
775+
type Int int
776+
777+
print(F[int]())
778+
print(F[Int]())
779+
}`
780+
781+
srcFiles := []srctesting.Source{{Name: `main.go`, Contents: []byte(src)}}
782+
root := srctesting.ParseSources(t, srcFiles, nil)
783+
archives := compileProject(t, root, false)
784+
mainPkg := archives[root.PkgPath]
785+
insts := collectDeclInstances(t, mainPkg)
786+
787+
exp := []string{
788+
`F[main.Int·2]`,
789+
`F[int]`,
790+
`T[main.Int·2]`, // `T` from `F[main.Int·2]`
791+
`T[int]`, // `T` from `F[Int]`
792+
}
793+
if diff := cmp.Diff(exp, insts); len(diff) > 0 {
794+
t.Errorf("the instances of generics are different:\n%s", diff)
795+
}
796+
}
797+
798+
func TestNestedGenericTypeInGenericFunc(t *testing.T) {
799+
// This is a subset of the type param nested test from the go repo.
755800
// See https://github.com/golang/go/blob/go1.19.13/test/typeparam/nested.go
756801
// The test is failing because nested types aren't being typed differently.
757802
// For example the type of `T[int]` below is different based on `F[X]`
@@ -761,49 +806,56 @@ func TestNestedTypesDuplicateDecls(t *testing.T) {
761806
src := `
762807
package main
763808
764-
type intish interface { ~int }
765-
766-
func F[A intish]() {
767-
type T[B intish] struct{}
768-
print(T[int]{})
809+
func F[A any]() any {
810+
type T[B any] struct{
811+
a A
812+
b B
813+
}
814+
return T[int]{}
769815
}
770816
771817
func main() {
772818
type Int int
773819
774-
F[int]()
775-
F[Int]()
820+
print(F[int]())
821+
print(F[Int]())
776822
}`
777823

778824
srcFiles := []srctesting.Source{{Name: `main.go`, Contents: []byte(src)}}
779825
root := srctesting.ParseSources(t, srcFiles, nil)
780826
archives := compileProject(t, root, false)
781827
mainPkg := archives[root.PkgPath]
828+
insts := collectDeclInstances(t, mainPkg)
829+
830+
exp := []string{
831+
`F[Int]`,
832+
`F[int]`,
833+
`T[Int;int]`,
834+
`T[int;int]`,
835+
}
836+
if diff := cmp.Diff(exp, insts); len(diff) > 0 {
837+
t.Errorf("the instances of generics are different:\n%s", diff)
838+
}
839+
}
840+
841+
func collectDeclInstances(t *testing.T, pkg *Archive) []string {
842+
t.Helper()
782843

783844
// Regex to match strings like `Foo[42 /* bar */] =` and capture
784845
// the name (`Foo`), the index (`42`), and the instance type (`bar`).
785846
rex := regexp.MustCompile(`^\s*(\w+)\s*\[\s*(\d+)\s*\/\*(.+)\*\/\s*\]\s*\=`)
786847

787-
// Collect all instances of generics (i.e. `Foo[bar]`) written to the decl code.
848+
// Collect all instances of generics (e.g. `Foo[bar]`) written to the decl code.
788849
insts := []string{}
789-
for _, decl := range mainPkg.Declarations {
850+
for _, decl := range pkg.Declarations {
790851
if match := rex.FindAllStringSubmatch(string(decl.DeclCode), 1); len(match) > 0 {
791852
instance := match[0][1] + `[` + strings.TrimSpace(match[0][3]) + `]`
792853
instance = strings.ReplaceAll(instance, `command-line-arguments.`, ``)
793854
insts = append(insts, instance)
794855
}
795856
}
796857
sort.Strings(insts)
797-
798-
exp := []string{
799-
`F[Int]`,
800-
`F[int]`,
801-
`T[Int;int]`,
802-
`T[int;int]`,
803-
}
804-
if diff := cmp.Diff(exp, insts); len(diff) > 0 {
805-
t.Errorf("the instances of generics are different:\n%s", diff)
806-
}
858+
return insts
807859
}
808860

809861
func compareOrder(t *testing.T, sourceFiles []srctesting.Source, minify bool) {

compiler/decls.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,9 @@ func (fc *funcContext) newNamedTypeInstDecl(inst typeparams.Instance) (*Decl, er
518518
// initialize themselves.
519519
switch t := underlying.(type) {
520520
case *types.Array, *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Slice, *types.Signature, *types.Struct:
521+
522+
fmt.Printf(">>>[A] (%[1]T): %[1]v\n", t) // TODO(grantnelson-wf): remove
523+
521524
d.TypeInitCode = fc.CatchOutput(0, func() {
522525
fc.Printf("%s.init(%s);", fc.instName(inst), fc.initArgs(t))
523526
})

compiler/internal/typeparams/collect.go

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,13 @@ func (r *Resolver) SubstituteSelection(sel typesutil.Selection) typesutil.Select
8282
}
8383
}
8484

85+
func (r *Resolver) String() string { // TODO(grantnelson-wf): remove
86+
if r == nil || r.subster == nil {
87+
return "<nil Resolver>"
88+
}
89+
return r.subster.String()
90+
}
91+
8592
// ToSlice converts TypeParamList into a slice with the same order of entries.
8693
func ToSlice(tpl *types.TypeParamList) []*types.TypeParam {
8794
result := make([]*types.TypeParam, tpl.Len())
@@ -114,10 +121,21 @@ func (c *visitor) Visit(n ast.Node) (w ast.Visitor) {
114121
}
115122

116123
instance, ok := c.info.Instances[ident]
117-
if !ok {
124+
if ok {
125+
c.addNamedInstance(ident, instance)
126+
return
127+
}
128+
129+
def, ok := c.info.Defs[ident]
130+
if ok && def != nil {
131+
c.addNestedNamed(ident, def)
118132
return
119133
}
120134

135+
return
136+
}
137+
138+
func (c *visitor) addNamedInstance(ident *ast.Ident, instance types.Instance) {
121139
obj := c.info.ObjectOf(ident)
122140

123141
// For types embedded in structs, the object the identifier resolves to is a
@@ -145,7 +163,21 @@ func (c *visitor) Visit(n ast.Node) (w ast.Visitor) {
145163
})
146164
}
147165
}
148-
return
166+
167+
fmt.Printf(">>>[X] %s => %v\n", ident.Name, obj) // TODO(grantnelson-wf): remove
168+
}
169+
170+
// TODO(grantnelson-wf): finish or remove
171+
func (c *visitor) addNestedNamed(ident *ast.Ident, obj types.Object) {
172+
typ := obj.Type()
173+
if ptr, ok := typ.(*types.Pointer); ok {
174+
typ = ptr.Elem()
175+
}
176+
if t, ok := typ.(*types.Named); ok {
177+
obj = t.Obj()
178+
}
179+
180+
fmt.Printf(">>>[Y] %s => %v\n", ident.Name, obj) // TODO(grantnelson-wf): remove
149181
}
150182

151183
// seedVisitor implements ast.Visitor that collects information necessary to
@@ -173,11 +205,11 @@ func (c *seedVisitor) Visit(n ast.Node) ast.Visitor {
173205
sig := obj.Type().(*types.Signature)
174206
if sig.TypeParams().Len() != 0 || sig.RecvTypeParams().Len() != 0 {
175207
c.objMap[obj] = n
176-
return &seedVisitor{
208+
return newPrinter(&seedVisitor{
177209
visitor: c.visitor,
178210
objMap: c.objMap,
179211
mapOnly: true,
180-
}
212+
}, "FuncDeclSeed")
181213
}
182214
case *ast.TypeSpec:
183215
obj := c.info.Defs[n.Name]
@@ -236,7 +268,7 @@ func (c *Collector) Scan(pkg *types.Package, files ...*ast.File) {
236268
objMap: objMap,
237269
}
238270
for _, file := range files {
239-
ast.Walk(&sc, file)
271+
ast.Walk(newPrinter(&sc, "Seed"), file)
240272
}
241273

242274
for iset := c.Instances.Pkg(pkg); !iset.exhausted(); {
@@ -248,15 +280,59 @@ func (c *Collector) Scan(pkg *types.Package, files ...*ast.File) {
248280
resolver: NewResolver(c.TContext, ToSlice(SignatureTypeParams(typ)), inst.TArgs),
249281
info: c.Info,
250282
}
251-
ast.Walk(&v, objMap[inst.Object])
283+
ast.Walk(newPrinter(&v, "Signature"), objMap[inst.Object])
252284
case *types.Named:
253285
obj := typ.Obj()
254286
v := visitor{
255287
instances: c.Instances,
256288
resolver: NewResolver(c.TContext, ToSlice(typ.TypeParams()), inst.TArgs),
257289
info: c.Info,
258290
}
259-
ast.Walk(&v, objMap[obj])
291+
ast.Walk(newPrinter(&v, "Named"), objMap[obj])
292+
}
293+
}
294+
}
295+
296+
type printer struct { // TODO(grantnelson-wf): remove
297+
inner ast.Visitor
298+
title string
299+
indent string
300+
}
301+
302+
func newPrinter(inner ast.Visitor, title string) *printer {
303+
return &printer{
304+
inner: inner,
305+
title: title,
306+
indent: ``,
307+
}
308+
}
309+
310+
func (p *printer) Visit(n ast.Node) (w ast.Visitor) {
311+
if n == nil {
312+
if len(p.indent) >= 2 {
313+
p.indent = p.indent[:len(p.indent)-2]
314+
}
315+
} else {
316+
p.indent += " "
317+
if id, ok := n.(*ast.Ident); ok {
318+
fmt.Printf("%s%s(%T)%q\n", p.title, p.indent, n, id.Name)
319+
} else {
320+
fmt.Printf("%s%s(%T)\n", p.title, p.indent, n)
321+
}
322+
}
323+
v2 := p.inner.Visit(n)
324+
if v2 != nil {
325+
if v2 == p.inner {
326+
return p
327+
}
328+
if _, ok := v2.(*printer); ok {
329+
return v2
330+
}
331+
v2 = &printer{
332+
inner: v2,
333+
title: p.title,
334+
indent: p.indent,
260335
}
261336
}
337+
return v2
262338
}

compiler/internal/typeparams/instance.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ import (
1515
type Instance struct {
1616
Object types.Object // Object to be instantiated.
1717
TArgs typesutil.TypeList // Type params to instantiate with.
18+
19+
// TNest is the type params of the function this object was nested with-in.
20+
// e.g. In `func A[X any]() { type B[Y any] struct {} }` the `X`
21+
// from `A` is the context of `B[Y]` thus creating `B[X;Y]`.
22+
TNest typesutil.TypeList
1823
}
1924

2025
// String returns a string representation of the Instance.
@@ -32,11 +37,29 @@ func (i *Instance) String() string {
3237

3338
// TypeString returns a Go type string representing the instance (suitable for %T verb).
3439
func (i *Instance) TypeString() string {
40+
return fmt.Sprintf("%s.%s%s", i.Object.Pkg().Name(), i.Object.Name(), i.typeParamsString())
41+
}
42+
43+
// typeParamsString returns part of a Go type string that represents the type
44+
// parameters of the instance including the nesting type parameters, e.g. [X;Y,Z].
45+
func (i *Instance) typeParamsString() string {
46+
hasNest := len(i.TNest) > 0
47+
hasArgs := len(i.TArgs) > 0
3548
tArgs := ""
36-
if len(i.TArgs) > 0 {
37-
tArgs = "[" + i.TArgs.String() + "]"
49+
if hasNest || hasArgs {
50+
tArgs = "["
51+
if hasNest {
52+
tArgs = i.TNest.String()
53+
if hasArgs {
54+
tArgs += ";"
55+
}
56+
}
57+
if hasArgs {
58+
tArgs = i.TArgs.String()
59+
}
60+
tArgs += "]"
3861
}
39-
return fmt.Sprintf("%s.%s%s", i.Object.Pkg().Name(), i.Object.Name(), tArgs)
62+
return tArgs
4063
}
4164

4265
// IsTrivial returns true if this is an instance of a non-generic object.

compiler/internal/typeparams/map.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@ type InstanceMap[V any] struct {
3838
// If the given key isn't found, an empty bucket and -1 are returned.
3939
func (im *InstanceMap[V]) findIndex(key Instance) (mapBucket[V], int) {
4040
if im != nil && im.data != nil {
41-
bucket := im.data[key.Object][typeHash(im.hasher, key.TArgs...)]
41+
bucket := im.data[key.Object][typeHash(im.hasher, key.TNest, key.TArgs)]
4242
for i, candidate := range bucket {
43-
if candidate != nil && candidate.key.TArgs.Equal(key.TArgs) {
43+
if candidate != nil &&
44+
candidate.key.TNest.Equal(key.TNest) &&
45+
candidate.key.TArgs.Equal(key.TArgs) {
4446
return bucket, i
4547
}
4648
}
@@ -82,15 +84,15 @@ func (im *InstanceMap[V]) Set(key Instance, value V) V {
8284
if _, ok := im.data[key.Object]; !ok {
8385
im.data[key.Object] = mapBuckets[V]{}
8486
}
85-
bucketID := typeHash(im.hasher, key.TArgs...)
87+
bucketID := typeHash(im.hasher, key.TNest, key.TArgs)
8688

8789
// If there is already an identical key in the map, override the entry value.
8890
hole := -1
8991
bucket := im.data[key.Object][bucketID]
9092
for i, candidate := range bucket {
9193
if candidate == nil {
9294
hole = i
93-
} else if candidate.key.TArgs.Equal(key.TArgs) {
95+
} else if candidate.key.TNest.Equal(key.TNest) && candidate.key.TArgs.Equal(key.TArgs) {
9496
old := candidate.value
9597
candidate.value = value
9698
return old
@@ -185,8 +187,11 @@ func (im *InstanceMap[V]) String() string {
185187
// Provided hasher is used to compute hashes of individual types, which are
186188
// xor'ed together. Xor preserves bit distribution property, so the combined
187189
// hash should be as good for bucketing, as the original.
188-
func typeHash(hasher typeutil.Hasher, types ...types.Type) uint32 {
190+
func typeHash(hasher typeutil.Hasher, nestTypes []types.Type, types []types.Type) uint32 {
189191
var hash uint32
192+
for _, typ := range nestTypes {
193+
hash ^= hasher.Hash(typ)
194+
}
190195
for _, typ := range types {
191196
hash ^= hasher.Hash(typ)
192197
}

compiler/package.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,19 @@ func (fc *funcContext) initArgs(ty types.Type) string {
327327
}
328328
return fmt.Sprintf(`"%s", [%s]`, pkgPath, strings.Join(fields, ", "))
329329
case *types.TypeParam:
330-
err := bailout(fmt.Errorf(`%v has unexpected generic type parameter %T`, ty, ty))
330+
tr := fc.typeResolver.Substitute(ty)
331+
332+
fmt.Printf(">>>[1] ty: (%[1]T): %[1]v\n", ty) // TODO(grantnelson-wf): remove
333+
fmt.Printf(">>>[1] tr: (%[1]T): %[1]v\n", tr)
334+
fmt.Printf(">>>[1] funcRef: %q\n", fc.funcRef)
335+
fmt.Printf(">>>[1]-----------------\n")
336+
fmt.Printf("%s\n", fc.typeResolver.String())
337+
fmt.Printf(">>>[1]-----------------\n")
338+
339+
if tr != ty {
340+
return fc.initArgs(tr)
341+
}
342+
err := bailout(fmt.Errorf(`"%v" has unexpected generic type parameter %T`, ty, ty))
331343
panic(err)
332344
default:
333345
err := bailout(fmt.Errorf("%v has unexpected type %T", ty, ty))

0 commit comments

Comments
 (0)