diff --git a/go.mod b/go.mod index 63a30d2..212fa9c 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,7 @@ module github.com/jfeliu007/goplantuml -go 1.17 +go 1.18 -require ( - github.com/spf13/afero v1.8.2 - golang.org/x/text v0.3.7 // indirect -) +require github.com/spf13/afero v1.8.2 + +require golang.org/x/text v0.3.7 // indirect diff --git a/parser/class_parser.go b/parser/class_parser.go index 7526b00..491d68d 100644 --- a/parser/class_parser.go +++ b/parser/class_parser.go @@ -10,7 +10,6 @@ call the Render() function and this will return a string with the class diagram. See github.com/jfeliu007/goplantuml/cmd/goplantuml/main.go for a command that uses this functions and outputs the text to the console. - */ package parser @@ -221,17 +220,19 @@ func (p *ClassParser) parsePackage(node ast.Node) { for fileName := range pack.Files { sortedFiles = append(sortedFiles, fileName) } + sort.Strings(sortedFiles) for _, fileName := range sortedFiles { + if strings.HasSuffix(fileName, "_test.go") { + continue + } - if !strings.HasSuffix(fileName, "_test.go") { - f := pack.Files[fileName] - for _, d := range f.Imports { - p.parseImports(d) - } - for _, d := range f.Decls { - p.parseFileDeclarations(d) - } + f := pack.Files[fileName] + for _, d := range f.Imports { + p.parseImports(d) + } + for _, d := range f.Decls { + p.parseFileDeclarations(d) } } } @@ -267,7 +268,6 @@ func (p *ClassParser) parseFileDeclarations(node ast.Decl) { } func (p *ClassParser) handleFuncDecl(decl *ast.FuncDecl) { - if decl.Recv != nil { if decl.Recv.List == nil { return @@ -296,10 +296,18 @@ func (p *ClassParser) handleFuncDecl(decl *ast.FuncDecl) { } } -func handleGenDecStructType(p *ClassParser, typeName string, c *ast.StructType) { +func handleGenDecStructType(p *ClassParser, typeName string, c *ast.StructType, typeParams *ast.FieldList) { for _, f := range c.Fields.List { p.getOrCreateStruct(typeName).AddField(f, p.allImports) } + + if typeParams == nil { + return + } + + for _, tp := range typeParams.List { + p.getOrCreateStruct(typeName).AddTypeParam(tp) + } } func handleGenDecInterfaceType(p *ClassParser, typeName string, c *ast.InterfaceType) { @@ -307,13 +315,11 @@ func handleGenDecInterfaceType(p *ClassParser, typeName string, c *ast.Interface switch t := f.Type.(type) { case *ast.FuncType: p.getOrCreateStruct(typeName).AddMethod(f, p.allImports) - break case *ast.Ident: f, _ := getFieldType(t, p.allImports) st := p.getOrCreateStruct(typeName) f = replacePackageConstant(f, st.PackageName) st.AddToComposition(f) - break } } } @@ -338,7 +344,7 @@ func (p *ClassParser) processSpec(spec ast.Spec) { switch c := v.Type.(type) { case *ast.StructType: declarationType = "class" - handleGenDecStructType(p, typeName, c) + handleGenDecStructType(p, typeName, c, v.TypeParams) case *ast.InterfaceType: declarationType = "interface" handleGenDecInterfaceType(p, typeName, c) @@ -379,7 +385,6 @@ func (p *ClassParser) processSpec(spec ast.Spec) { p.allRenamedStructs[pack[0]][renamedClass] = pack[1] } } - return } // If this element is an array or a pointer, this function will return the type that is closer to these @@ -465,7 +470,7 @@ func (p *ClassParser) renderStructures(pack string, structures map[string]*Struc str.WriteLineWithDepth(2, aliasComplexNameComment) str.WriteLineWithDepth(1, "}") } - str.WriteLineWithDepth(0, fmt.Sprintf(`}`)) + str.WriteLineWithDepth(0, `}`) if p.renderingOptions.Compositions { str.WriteLineWithDepth(0, composition.String()) } @@ -479,7 +484,6 @@ func (p *ClassParser) renderStructures(pack string, structures map[string]*Struc } func (p *ClassParser) renderAliases(str *LineStringBuilder) { - aliasString := "" if p.renderingOptions.ConnectionLabels { aliasString = aliasOf @@ -505,7 +509,6 @@ func (p *ClassParser) renderAliases(str *LineStringBuilder) { } func (p *ClassParser) renderStructure(structure *Struct, pack string, name string, str *LineStringBuilder, composition *LineStringBuilder, extends *LineStringBuilder, aggregations *LineStringBuilder) { - privateFields := &LineStringBuilder{} publicFields := &LineStringBuilder{} privateMethods := &LineStringBuilder{} @@ -518,9 +521,24 @@ func (p *ClassParser) renderStructure(structure *Struct, pack string, name strin case "alias": sType = "<< (T, #FF7700) >> " renderStructureType = "class" + } + types := "" + if structure.Generics.exists() { + types = "<" + for t := range structure.Generics.Types { + types += fmt.Sprintf("%s, ", t) + } + types = strings.TrimSuffix(types, ", ") + types += " constrains " + for _, n := range structure.Generics.Names { + types += fmt.Sprintf("%s, ", n) + } + types = strings.TrimSuffix(types, ", ") + types += ">" } - str.WriteLineWithDepth(1, fmt.Sprintf(`%s %s %s {`, renderStructureType, name, sType)) + + str.WriteLineWithDepth(1, fmt.Sprintf(`%s %s%s %s {`, renderStructureType, name, types, sType)) p.renderStructFields(structure, privateFields, publicFields) p.renderStructMethods(structure, privateMethods, publicMethods) p.renderCompositions(structure, name, composition) @@ -538,7 +556,7 @@ func (p *ClassParser) renderStructure(structure *Struct, pack string, name strin if publicMethods.Len() > 0 { str.WriteLineWithDepth(0, publicMethods.String()) } - str.WriteLineWithDepth(1, fmt.Sprintf(`}`)) + str.WriteLineWithDepth(1, `}`) } func (p *ClassParser) renderCompositions(structure *Struct, name string, composition *LineStringBuilder) { @@ -562,7 +580,6 @@ func (p *ClassParser) renderCompositions(structure *Struct, name string, composi } func (p *ClassParser) renderAggregations(structure *Struct, name string, aggregations *LineStringBuilder) { - aggregationMap := structure.Aggregations if p.renderingOptions.AggregatePrivateMembers { p.updatePrivateAggregations(structure, aggregationMap) @@ -571,7 +588,6 @@ func (p *ClassParser) renderAggregations(structure *Struct, name string, aggrega } func (p *ClassParser) updatePrivateAggregations(structure *Struct, aggregationsMap map[string]struct{}) { - for agg := range structure.PrivateAggregations { aggregationsMap[agg] = struct{}{} } @@ -600,13 +616,13 @@ func (p *ClassParser) renderAggregationMap(aggregationMap map[string]struct{}, s } func (p *ClassParser) getPackageName(t string, st *Struct) string { - packageName := st.PackageName if isPrimitiveString(t) { packageName = builtinPackageName } return packageName } + func (p *ClassParser) renderExtends(structure *Struct, name string, extends *LineStringBuilder) { orderedExtends := []string{} @@ -628,7 +644,6 @@ func (p *ClassParser) renderExtends(structure *Struct, name string, extends *Lin } func (p *ClassParser) renderStructMethods(structure *Struct, privateMethods *LineStringBuilder, publicMethods *LineStringBuilder) { - for _, method := range structure.Functions { accessModifier := "+" if unicode.IsLower(rune(method.Name[0])) { @@ -685,6 +700,7 @@ func (p *ClassParser) getOrCreateStruct(name string) *Struct { Functions: make([]*Function, 0), Fields: make([]*Field, 0), Type: "", + Generics: NewGeneric(), Composition: make(map[string]struct{}, 0), Extends: make(map[string]struct{}, 0), Aggregations: make(map[string]struct{}, 0), diff --git a/parser/class_parser_test.go b/parser/class_parser_test.go index 8ca95bc..92b74b5 100644 --- a/parser/class_parser_test.go +++ b/parser/class_parser_test.go @@ -2,7 +2,7 @@ package parser import ( "go/ast" - "io/ioutil" + "os" "reflect" "testing" ) @@ -94,6 +94,7 @@ func TestGetOrCreateStruct(t *testing.T) { Functions: make([]*Function, 0), Fields: make([]*Field, 0), Type: "", + Generics: NewGeneric(), Composition: make(map[string]struct{}, 0), Extends: make(map[string]struct{}, 0), Aggregations: make(map[string]struct{}, 0), @@ -181,7 +182,6 @@ func TestRenderStructFields(t *testing.T) { } func TestRenderStructures(t *testing.T) { - structMap := map[string]*Struct{ "MainClass": getTestStruct(), } @@ -296,6 +296,7 @@ func getTestStruct() *Struct { ReturnValues: []string{"int"}, }, }, + Generics: NewGeneric(), } } @@ -563,7 +564,7 @@ func TestRender(t *testing.T) { }) resultRender := parser.Render() - result, err := ioutil.ReadFile("../testingsupport/testingsupport.puml") + result, err := os.ReadFile("../testingsupport/testingsupport.puml") if err != nil { t.Errorf("TestRender: expected no errors reading testing file, got %s", err.Error()) } @@ -592,7 +593,7 @@ func TestMultipleFolders(t *testing.T) { } resultRender := parser.Render() - result, err := ioutil.ReadFile("../testingsupport/subfolder1-2.puml") + result, err := os.ReadFile("../testingsupport/subfolder1-2.puml") if err != nil { t.Errorf("TestMultipleFolders: expected no errors reading testing file, got %s", err.Error()) } diff --git a/parser/field.go b/parser/field.go index fb25a88..6874538 100644 --- a/parser/field.go +++ b/parser/field.go @@ -9,15 +9,15 @@ import ( const packageConstant = "{packageName}" -//Field can hold the name and type of any field +// Field can hold the name and type of any field type Field struct { Name string Type string FullType string } -//Returns a string representation of the given expression if it was recognized. -//Refer to the implementation to see the different string representations. +// Returns a string representation of the given expression if it was recognized. +// Refer to the implementation to see the different string representations. func getFieldType(exp ast.Expr, aliases map[string]string) (string, []string) { switch v := exp.(type) { case *ast.Ident: @@ -40,12 +40,23 @@ func getFieldType(exp ast.Expr, aliases map[string]string) (string, []string) { return getFuncType(v, aliases) case *ast.Ellipsis: return getEllipsis(v, aliases) + case *ast.IndexExpr: + return getIndexExpr(v, aliases) + case *ast.IndexListExpr: + return getIndexListExpr(v, aliases) } return "", []string{} } -func getIdent(v *ast.Ident, aliases map[string]string) (string, []string) { +func getIndexExpr(v *ast.IndexExpr, aliases map[string]string) (string, []string) { + return getFieldType(v.X, aliases) +} + +func getIndexListExpr(v *ast.IndexListExpr, aliases map[string]string) (string, []string) { + return getFieldType(v.X, aliases) +} +func getIdent(v *ast.Ident, aliases map[string]string) (string, []string) { if isPrimitive(v) { return v.Name, []string{} } @@ -59,7 +70,6 @@ func getArrayType(v *ast.ArrayType, aliases map[string]string) (string, []string } func getSelectorExp(v *ast.SelectorExpr, aliases map[string]string) (string, []string) { - packageName := v.X.(*ast.Ident).Name if realPackageName, ok := aliases[packageName]; ok { packageName = realPackageName @@ -69,26 +79,22 @@ func getSelectorExp(v *ast.SelectorExpr, aliases map[string]string) (string, []s } func getMapType(v *ast.MapType, aliases map[string]string) (string, []string) { - t1, f1 := getFieldType(v.Key, aliases) t2, f2 := getFieldType(v.Value, aliases) return fmt.Sprintf("map[%s]%s", t1, t2), append(f1, f2...) } func getStarExp(v *ast.StarExpr, aliases map[string]string) (string, []string) { - t, f := getFieldType(v.X, aliases) return fmt.Sprintf("*%s", t), f } func getChanType(v *ast.ChanType, aliases map[string]string) (string, []string) { - t, f := getFieldType(v.Value, aliases) return fmt.Sprintf("chan %s", t), f } func getStructType(v *ast.StructType, aliases map[string]string) (string, []string) { - fieldList := make([]string, 0) for _, field := range v.Fields.List { t, _ := getFieldType(field.Type, aliases) @@ -98,7 +104,6 @@ func getStructType(v *ast.StructType, aliases map[string]string) (string, []stri } func getInterfaceType(v *ast.InterfaceType, aliases map[string]string) (string, []string) { - methods := make([]string, 0) for _, field := range v.Methods.List { methodName := "" @@ -112,7 +117,6 @@ func getInterfaceType(v *ast.InterfaceType, aliases map[string]string) (string, } func getFuncType(v *ast.FuncType, aliases map[string]string) (string, []string) { - function := getFunction(v, "", aliases, "") params := make([]string, 0) for _, pa := range function.Parameters { @@ -120,9 +124,7 @@ func getFuncType(v *ast.FuncType, aliases map[string]string) (string, []string) } returns := "" returnList := make([]string, 0) - for _, re := range function.ReturnValues { - returnList = append(returnList, re) - } + returnList = append(returnList, function.ReturnValues...) if len(returnList) > 1 { returns = fmt.Sprintf("(%s)", strings.Join(returnList, ", ")) } else { diff --git a/parser/struct.go b/parser/struct.go index 1383775..c4ac29e 100644 --- a/parser/struct.go +++ b/parser/struct.go @@ -5,13 +5,14 @@ import ( "unicode" ) -//Struct represent a struct in golang, it can be of Type "class" or "interface" and can be associated -//with other structs via Composition and Extends +// Struct represent a struct in golang, it can be of Type "class" or "interface" and can be associated +// with other structs via Composition and Extends type Struct struct { PackageName string Functions []*Function Fields []*Field Type string + Generics *Generic Composition map[string]struct{} Extends map[string]struct{} Aggregations map[string]struct{} @@ -38,9 +39,9 @@ func (st *Struct) ImplementsInterface(inter *Struct) bool { return true } -//AddToComposition adds the composition relation to the structure. We want to make sure that *ExampleStruct -//gets added as ExampleStruct so that we can properly build the relation later to the -//class identifier +// AddToComposition adds the composition relation to the structure. We want to make sure that *ExampleStruct +// gets added as ExampleStruct so that we can properly build the relation later to the +// class identifier func (st *Struct) AddToComposition(fType string) { if len(fType) == 0 { return @@ -51,9 +52,9 @@ func (st *Struct) AddToComposition(fType string) { st.Composition[fType] = struct{}{} } -//AddToExtends Adds an extends relationship to this struct. We want to make sure that *ExampleStruct -//gets added as ExampleStruct so that we can properly build the relation later to the -//class identifier +// AddToExtends Adds an extends relationship to this struct. We want to make sure that *ExampleStruct +// gets added as ExampleStruct so that we can properly build the relation later to the +// class identifier func (st *Struct) AddToExtends(fType string) { if len(fType) == 0 { return @@ -64,18 +65,18 @@ func (st *Struct) AddToExtends(fType string) { st.Extends[fType] = struct{}{} } -//AddToAggregation adds an aggregation type to the list of aggregations +// AddToAggregation adds an aggregation type to the list of aggregations func (st *Struct) AddToAggregation(fType string) { st.Aggregations[fType] = struct{}{} } -//addToPrivateAggregation adds an aggregation type to the list of aggregations for private members +// addToPrivateAggregation adds an aggregation type to the list of aggregations for private members func (st *Struct) addToPrivateAggregation(fType string) { st.PrivateAggregations[fType] = struct{}{} } -//AddField adds a field into this structure. It parses the ast.Field and extract all -//needed information +// AddField adds a field into this structure. It parses the ast.Field and extract all +// needed information func (st *Struct) AddField(field *ast.Field, aliases map[string]string) { theType, fundamentalTypes := getFieldType(field.Type, aliases) theType = replacePackageConstant(theType, "") @@ -103,7 +104,11 @@ func (st *Struct) AddField(field *ast.Field, aliases map[string]string) { } } -//AddMethod Parse the Field and if it is an ast.FuncType, then add the methods into the structure +func (st *Struct) AddTypeParam(field *ast.Field) { + st.Generics.getNames(field).getTypes(field) +} + +// AddMethod Parse the Field and if it is an ast.FuncType, then add the methods into the structure func (st *Struct) AddMethod(method *ast.Field, aliases map[string]string) { f, ok := method.Type.(*ast.FuncType) if !ok { diff --git a/parser/struct_test.go b/parser/struct_test.go index 9ae7013..c8e9d58 100644 --- a/parser/struct_test.go +++ b/parser/struct_test.go @@ -326,7 +326,6 @@ func TestAddToExtension(t *testing.T) { } func arrayContains(a map[string]struct{}, text string) bool { - found := false for v := range a { if v == text { diff --git a/parser/type.go b/parser/type.go new file mode 100644 index 0000000..32e3158 --- /dev/null +++ b/parser/type.go @@ -0,0 +1,78 @@ +package parser + +import ( + "go/ast" +) + +type Generic struct { + Names []string + Types map[string]struct{} // Keep a set to cover the case when there are multiple type names using the same type. +} + +func NewGeneric() *Generic { + return &Generic{Types: make(map[string]struct{})} +} + +func (g *Generic) exists() bool { + return len(g.Names) != 0 && len(g.Types) != 0 +} + +func (g *Generic) getNames(field *ast.Field) *Generic { + for _, name := range field.Names { + g.Names = append(g.Names, name.String()) + } + + return g +} + +func (g *Generic) getTypes(field *ast.Field) *Generic { + switch f := field.Type.(type) { + case *ast.Ident: + g.Types[f.Name] = struct{}{} + case *ast.BinaryExpr: + switch x := f.X.(type) { + case *ast.Ident: + g.Types[x.Name] = struct{}{} + } + + switch y := f.Y.(type) { + case *ast.Ident: + g.Types[y.Name] = struct{}{} + } + + // The below, while ugly, handles scenarios where we have N binary expressions. + // An example of this could be + // type foo[T string | bool | int | int16 | float64] struct{} + switch f.X.(type) { + case *ast.BinaryExpr: + var x ast.Expr = f.X.(*ast.BinaryExpr) + var process = true + for process { + switch xt := x.(type) { + case *ast.Ident: + g.Types[xt.Name] = struct{}{} + } + + switch yt := x.(type) { + case *ast.Ident: + g.Types[yt.Name] = struct{}{} + case *ast.BinaryExpr: + switch ytt := yt.Y.(type) { + case *ast.Ident: + g.Types[ytt.Name] = struct{}{} + } + } + + newX, safe := x.(*ast.BinaryExpr) + process = safe + if safe { + x = newX.X + } + } + } + case *ast.InterfaceType: + g.Types["interface"] = struct{}{} + } + + return g +} diff --git a/parser/type_test.go b/parser/type_test.go new file mode 100644 index 0000000..bb093de --- /dev/null +++ b/parser/type_test.go @@ -0,0 +1,82 @@ +package parser + +import ( + "go/ast" + "reflect" + "testing" +) + +func TestNewGeneric(t *testing.T) { + g := NewGeneric() + if g == nil { + t.Fatal("Returned value should not be nil") + } +} + +func TestExists(t *testing.T) { + g := NewGeneric() + if g.exists() { + t.Fatal("Should not exist at this point") + } + + g.getNames(&ast.Field{ + Names: []*ast.Ident{{Name: "test"}}, + }).getTypes(&ast.Field{ + Type: new(ast.Ident), + }) + + if !g.exists() { + t.Fatal("Should exist at this point") + } +} + +func TestGetTypes(t *testing.T) { + table := []struct { + names []*ast.Ident + typ func() ast.Expr + expectedNames []string + expectedTypes map[string]struct{} + }{{ + names: []*ast.Ident{{Name: "any"}}, + typ: func() ast.Expr { + e := new(ast.Ident) + e.Name = "any" + return e + }, + expectedNames: []string{"any"}, + expectedTypes: map[string]struct{}{"any": {}}, + }, { + names: []*ast.Ident{{Name: "int"}}, + typ: func() ast.Expr { + e := new(ast.BinaryExpr) + e.X = new(ast.Ident) + e.X.(*ast.Ident).Name = "int" + e.Y = new(ast.Ident) + e.Y.(*ast.Ident).Name = "bool" + return e + }, + expectedNames: []string{"int"}, + expectedTypes: map[string]struct{}{"int": {}, "bool": {}}, + }, { + names: []*ast.Ident{{Name: "interface"}}, + typ: func() ast.Expr { + e := new(ast.InterfaceType) + return e + }, + expectedNames: []string{"interface"}, + expectedTypes: map[string]struct{}{"interface": {}}, + }} + + for _, entry := range table { + g := NewGeneric() + field := &ast.Field{Names: entry.names, Type: entry.typ()} + g.getNames(field).getTypes(field) + if !reflect.DeepEqual(g.Names, entry.expectedNames) { + t.Errorf("Mismatched names: %v %v", g.Names, entry.expectedNames) + } + + if !reflect.DeepEqual(g.Types, entry.expectedTypes) { + t.Errorf("Mismatched types: %v %v", g.Types, entry.expectedTypes) + } + } +} diff --git a/testingsupport/generics/generics.go b/testingsupport/generics/generics.go new file mode 100644 index 0000000..afac0b6 --- /dev/null +++ b/testingsupport/generics/generics.go @@ -0,0 +1,25 @@ +package generics + +type SingleAny[T any] struct{} +type SingleString[T string] struct{} +type SingleFloat32[T float32] struct{} +type SingleFloat64[T float64] struct{} +type SingleInt[T int] struct{} +type SingleInt16[T int16] struct{} +type SingleInt32[T int32] struct{} +type SingleInt64[T int64] struct{} +type SingleBool[T bool] struct{} + +type OrAny[T any | any] struct{} +type OrMixed[T string | bool] struct{} +type ManyOrAny[T any | any | any | any] struct{} +type ManyOrMixed[T string | bool | int | int16] struct{} + +type MultipleAny[T any, K any] struct{} +type MultipleAnyOneType[T, K any] struct{} +type MultipleString[T string, K string] struct{} +type MultipleStringOneType[T, K string] struct{} + +type AnonIface[T interface{}] struct{} +type named interface{} +type NamedIface[T named] struct{}