diff --git a/generator/generator.go b/generator/generator.go index cfcd463c..0ecc3f55 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -459,7 +459,7 @@ func typeSpecs(f *ast.File) []*ast.TypeSpec { return result } -func getEmbeddedMethods(t ast.Expr, pr typePrinter, input targetProcessInput) (param genericParam, methods methodsList, err error) { +func getEmbeddedMethods(t ast.Expr, pr typePrinter, input targetProcessInput, checkInterface bool) (param genericParam, methods methodsList, err error) { param.Name, err = pr.PrintType(t) if err != nil { return @@ -471,13 +471,13 @@ func getEmbeddedMethods(t ast.Expr, pr typePrinter, input targetProcessInput) (p return case *ast.Ident: - methods, err = processIdent(v, input) + methods, err = processIdent(v, input, checkInterface) return } return } -func processEmbedded(t ast.Expr, pr typePrinter, input targetProcessInput) (genericParam genericParam, embeddedMethods methodsList, err error) { +func processEmbedded(t ast.Expr, pr typePrinter, input targetProcessInput, checkInterface bool) (genericParam genericParam, embeddedMethods methodsList, err error) { var x ast.Expr var hasGenericsParams bool var genericParams genericParams @@ -486,8 +486,12 @@ func processEmbedded(t ast.Expr, pr typePrinter, input targetProcessInput) (gene case *ast.IndexExpr: x = v.X hasGenericsParams = true - - genericParam, _, err = processEmbedded(v.Index, pr, input) + // Don't check if embedded interface's generic params are also interfaces, e.g. given the interface: + // type SomeInterface { + // EmbeddedGenericInterface[Bar] + // } + // we won't be checking if Bar is also an interface + genericParam, _, err = processEmbedded(v.Index, pr, input, false) if err != nil { return } @@ -501,7 +505,12 @@ func processEmbedded(t ast.Expr, pr typePrinter, input targetProcessInput) (gene if v.Indices != nil { for _, index := range v.Indices { - genericParam, _, err = processEmbedded(index, pr, input) + // Don't check if embedded interface's generic params are also interfaces, e.g. given the interface: + // type SomeInterface { + // EmbeddedGenericInterface[Bar] + // } + // we won't be checking if Bar is also an interface + genericParam, _, err = processEmbedded(index, pr, input, false) if err != nil { return } @@ -515,7 +524,7 @@ func processEmbedded(t ast.Expr, pr typePrinter, input targetProcessInput) (gene } input.genericParams = genericParams - genericParam, embeddedMethods, err = getEmbeddedMethods(x, pr, input) + genericParam, embeddedMethods, err = getEmbeddedMethods(x, pr, input, checkInterface) if err != nil { return } @@ -551,7 +560,7 @@ func processInterface(it *ast.InterfaceType, targetInput targetProcessInput) (me } default: - _, embeddedMethods, err = processEmbedded(v, pr, targetInput) + _, embeddedMethods, err = processEmbedded(v, pr, targetInput, true) } if err != nil { @@ -618,19 +627,23 @@ func mergeMethods(methods, embeddedMethods methodsList) (methodsList, error) { var errNotAnInterface = errors.New("embedded type is not an interface") -func processIdent(i *ast.Ident, input targetProcessInput) (methodsList, error) { +func processIdent(i *ast.Ident, input targetProcessInput, checkInterface bool) (methodsList, error) { var embeddedInterface *ast.InterfaceType var genericsTypes genericTypes for _, t := range input.types { if t.Name.Name == i.Name { var ok bool embeddedInterface, ok = t.Type.(*ast.InterfaceType) - if !ok { - return nil, errors.Wrap(errNotAnInterface, t.Name.Name) + if ok { + genericsTypes = buildGenericTypesFromSpec(t, input.types, input.typesPrefix) + break + } + + if !checkInterface { + break } - genericsTypes = buildGenericTypesFromSpec(t, input.types, input.typesPrefix) - break + return nil, errors.Wrap(errNotAnInterface, t.Name.Name) } } diff --git a/generator/generator_test.go b/generator/generator_test.go index 1e80cee9..fbe5ccff 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -111,8 +111,9 @@ func Test_findImportPathForName(t *testing.T) { func Test_processIdent(t *testing.T) { type args struct { - i *ast.Ident - input targetProcessInput + i *ast.Ident + input targetProcessInput + toCheckForInterface bool } tests := []struct { name string @@ -129,12 +130,24 @@ func Test_processIdent(t *testing.T) { input: targetProcessInput{ types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "name"}, Type: &ast.StructType{}}}, }, + toCheckForInterface: true, }, wantErr: true, inspectErr: func(err error, t *testing.T) { assert.Equal(t, errNotAnInterface, errors.Cause(err)) }, }, + { + name: "not an interface but no need to check", + args: args{ + i: &ast.Ident{Name: "name"}, + input: targetProcessInput{ + types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "name"}, Type: &ast.StructType{}}}, + }, + toCheckForInterface: false, + }, + wantErr: false, + }, { name: "embedded interface found", args: args{ @@ -142,6 +155,7 @@ func Test_processIdent(t *testing.T) { input: targetProcessInput{ types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "name"}, Type: &ast.InterfaceType{}}}, }, + toCheckForInterface: true, }, wantErr: false, }, @@ -152,7 +166,7 @@ func Test_processIdent(t *testing.T) { mc := minimock.NewController(t) defer mc.Wait(time.Second) - got1, err := processIdent(tt.args.i, tt.args.input) + got1, err := processIdent(tt.args.i, tt.args.input, tt.args.toCheckForInterface) assert.Equal(t, tt.want1, got1, "processIdent returned unexpected result") diff --git a/printer/printer.go b/printer/printer.go index e5914404..c98bf053 100644 --- a/printer/printer.go +++ b/printer/printer.go @@ -67,6 +67,10 @@ func (p *Printer) PrintType(node ast.Node) (string, error) { return p.printStruct(t) case *ast.Ident: return p.printIdent(t) + case *ast.IndexExpr: + return p.printGeneric(t) + case *ast.IndexListExpr: + return p.printGenericList(t) } err := printer.Fprint(p.buf, p.fs, node) @@ -151,6 +155,43 @@ func (p *Printer) printIdent(i *ast.Ident) (string, error) { return p.buf.String(), err } +func (p *Printer) printGeneric(pt *ast.IndexExpr) (string, error) { + t, err := p.PrintType(pt.X) + if err != nil { + return "", err + } + + generic, err := p.PrintType(pt.Index) + if err != nil { + return "", err + } + + return t + "[" + generic + "]", nil +} + +func (p *Printer) printGenericList(pt *ast.IndexListExpr) (string, error) { + t, err := p.PrintType(pt.X) + if err != nil { + return "", err + } + + baseStr := t + "[" + for i, expr := range pt.Indices { + generic, err := p.PrintType(expr) + if err != nil { + return "", err + } + + if i == len(pt.Indices)-1 { + baseStr = baseStr + generic + "]" + } else { + baseStr = baseStr + generic + ", " + } + } + + return baseStr, nil +} + func (p *Printer) printPointer(pt *ast.StarExpr) (string, error) { pointerTo, err := p.PrintType(pt.X) if err != nil { diff --git a/printer/printer_test.go b/printer/printer_test.go index 73143c3c..cc0dab35 100644 --- a/printer/printer_test.go +++ b/printer/printer_test.go @@ -776,6 +776,199 @@ func TestPrinter_printIdent(t *testing.T) { } } +func TestPrinter_printGeneric(t *testing.T) { + tests := []struct { + name string + init func(t minimock.Tester) *Printer + inspect func(r *Printer, t *testing.T) + + indexExpr *ast.IndexExpr + + want1 string + wantErr bool + inspectErr func(err error, t *testing.T) + }{ + { + name: "success", + indexExpr: &ast.IndexExpr{ + X: &ast.Ident{ + Name: "Bar", + }, + Index: &ast.Ident{ + Name: "Baz", + }, + }, + init: func(t minimock.Tester) *Printer { + return &Printer{ + typesPrefix: "prefix", + fs: token.NewFileSet(), + buf: bytes.NewBuffer([]byte{}), + types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Bar"}}, {Name: &ast.Ident{Name: "Baz"}}}, + } + }, + want1: "prefix.Bar[prefix.Baz]", + wantErr: false, + }, + { + name: "success, generic from other package", + indexExpr: &ast.IndexExpr{ + X: &ast.Ident{ + Name: "Bar", + }, + Index: &ast.SelectorExpr{ + X: &ast.Ident{ + Name: "otherpkg", + }, + Sel: &ast.Ident{ + Name: "Baz", + }, + }, + }, + init: func(t minimock.Tester) *Printer { + return &Printer{ + typesPrefix: "prefix", + fs: token.NewFileSet(), + buf: bytes.NewBuffer([]byte{}), + types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Bar"}}}, + } + }, + want1: "prefix.Bar[otherpkg.Baz]", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mc := minimock.NewController(t) + defer mc.Wait(time.Second) + + receiver := tt.init(mc) + + got1, err := receiver.printGeneric(tt.indexExpr) + + if tt.inspect != nil { + tt.inspect(receiver, t) + } + + assert.Equal(t, tt.want1, got1, "Printer.printGeneric returned unexpected result") + + if tt.wantErr { + if assert.Error(t, err) && tt.inspectErr != nil { + tt.inspectErr(err, t) + } + } else { + assert.NoError(t, err) + } + + }) + } +} + +func TestPrinter_printGenericList(t *testing.T) { + tests := []struct { + name string + init func(t minimock.Tester) *Printer + inspect func(r *Printer, t *testing.T) + + indexListExpr *ast.IndexListExpr + + want1 string + wantErr bool + inspectErr func(err error, t *testing.T) + }{ + { + name: "success", + indexListExpr: &ast.IndexListExpr{ + X: &ast.Ident{ + Name: "Bar", + }, + Indices: []ast.Expr{ + &ast.Ident{ + Name: "Baz", + }, + &ast.Ident{ + Name: "Bak", + }, + }, + }, + init: func(t minimock.Tester) *Printer { + return &Printer{ + typesPrefix: "prefix", + fs: token.NewFileSet(), + buf: bytes.NewBuffer([]byte{}), + types: []*ast.TypeSpec{ + {Name: &ast.Ident{Name: "Bar"}}, + {Name: &ast.Ident{Name: "Baz"}}, + {Name: &ast.Ident{Name: "Bak"}}, + }, + } + }, + want1: "prefix.Bar[prefix.Baz, prefix.Bak]", + wantErr: false, + }, + { + name: "success, generic from other package", + indexListExpr: &ast.IndexListExpr{ + X: &ast.Ident{ + Name: "Bar", + }, + Indices: []ast.Expr{ + &ast.Ident{ + Name: "Baz", + }, + &ast.SelectorExpr{ + X: &ast.Ident{ + Name: "otherpkg", + }, + Sel: &ast.Ident{ + Name: "Bak", + }, + }, + }, + }, + init: func(t minimock.Tester) *Printer { + return &Printer{ + typesPrefix: "prefix", + fs: token.NewFileSet(), + buf: bytes.NewBuffer([]byte{}), + types: []*ast.TypeSpec{ + {Name: &ast.Ident{Name: "Bar"}}, + {Name: &ast.Ident{Name: "Baz"}}, + }, + } + }, + want1: "prefix.Bar[prefix.Baz, otherpkg.Bak]", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mc := minimock.NewController(t) + defer mc.Wait(time.Second) + + receiver := tt.init(mc) + + got1, err := receiver.printGenericList(tt.indexListExpr) + + if tt.inspect != nil { + tt.inspect(receiver, t) + } + + assert.Equal(t, tt.want1, got1, "Printer.printGenericList returned unexpected result") + + if tt.wantErr { + if assert.Error(t, err) && tt.inspectErr != nil { + tt.inspectErr(err, t) + } + } else { + assert.NoError(t, err) + } + + }) + } +} + func TestPrinter_PrintType(t *testing.T) { tests := []struct { name string @@ -876,6 +1069,36 @@ func TestPrinter_PrintType(t *testing.T) { }, want1: "package.Identifier", }, + { + name: "generic type", + node: &ast.IndexExpr{X: &ast.Ident{Name: "Bar"}, Index: &ast.Ident{Name: "string"}}, + init: func(t minimock.Tester) *Printer { + return &Printer{ + fs: token.NewFileSet(), + buf: bytes.NewBuffer([]byte{}), + types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Bar"}}}, + } + }, + want1: "Bar[string]", + }, + { + name: "generic list type", + node: &ast.IndexListExpr{ + X: &ast.Ident{Name: "Bar"}, + Indices: []ast.Expr{ + &ast.Ident{Name: "string"}, + &ast.Ident{Name: "int"}, + }, + }, + init: func(t minimock.Tester) *Printer { + return &Printer{ + fs: token.NewFileSet(), + buf: bytes.NewBuffer([]byte{}), + types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Bar"}}}, + } + }, + want1: "Bar[string, int]", + }, } for _, tt := range tests {