diff --git a/go.mod b/go.mod
index 63a30d2..212fa9c 100644
--- a/go.mod
+++ b/go.mod
@@ -1,8 +1,7 @@
module github.com/jfeliu007/goplantuml
-go 1.17
+go 1.18
-require (
- github.com/spf13/afero v1.8.2
- golang.org/x/text v0.3.7 // indirect
-)
+require github.com/spf13/afero v1.8.2
+
+require golang.org/x/text v0.3.7 // indirect
diff --git a/parser/class_parser.go b/parser/class_parser.go
index 7526b00..491d68d 100644
--- a/parser/class_parser.go
+++ b/parser/class_parser.go
@@ -10,7 +10,6 @@ call the Render() function and this will return a string with the class diagram.
See github.com/jfeliu007/goplantuml/cmd/goplantuml/main.go for a command that uses this functions and outputs the text to
the console.
-
*/
package parser
@@ -221,17 +220,19 @@ func (p *ClassParser) parsePackage(node ast.Node) {
for fileName := range pack.Files {
sortedFiles = append(sortedFiles, fileName)
}
+
sort.Strings(sortedFiles)
for _, fileName := range sortedFiles {
+ if strings.HasSuffix(fileName, "_test.go") {
+ continue
+ }
- if !strings.HasSuffix(fileName, "_test.go") {
- f := pack.Files[fileName]
- for _, d := range f.Imports {
- p.parseImports(d)
- }
- for _, d := range f.Decls {
- p.parseFileDeclarations(d)
- }
+ f := pack.Files[fileName]
+ for _, d := range f.Imports {
+ p.parseImports(d)
+ }
+ for _, d := range f.Decls {
+ p.parseFileDeclarations(d)
}
}
}
@@ -267,7 +268,6 @@ func (p *ClassParser) parseFileDeclarations(node ast.Decl) {
}
func (p *ClassParser) handleFuncDecl(decl *ast.FuncDecl) {
-
if decl.Recv != nil {
if decl.Recv.List == nil {
return
@@ -296,10 +296,18 @@ func (p *ClassParser) handleFuncDecl(decl *ast.FuncDecl) {
}
}
-func handleGenDecStructType(p *ClassParser, typeName string, c *ast.StructType) {
+func handleGenDecStructType(p *ClassParser, typeName string, c *ast.StructType, typeParams *ast.FieldList) {
for _, f := range c.Fields.List {
p.getOrCreateStruct(typeName).AddField(f, p.allImports)
}
+
+ if typeParams == nil {
+ return
+ }
+
+ for _, tp := range typeParams.List {
+ p.getOrCreateStruct(typeName).AddTypeParam(tp)
+ }
}
func handleGenDecInterfaceType(p *ClassParser, typeName string, c *ast.InterfaceType) {
@@ -307,13 +315,11 @@ func handleGenDecInterfaceType(p *ClassParser, typeName string, c *ast.Interface
switch t := f.Type.(type) {
case *ast.FuncType:
p.getOrCreateStruct(typeName).AddMethod(f, p.allImports)
- break
case *ast.Ident:
f, _ := getFieldType(t, p.allImports)
st := p.getOrCreateStruct(typeName)
f = replacePackageConstant(f, st.PackageName)
st.AddToComposition(f)
- break
}
}
}
@@ -338,7 +344,7 @@ func (p *ClassParser) processSpec(spec ast.Spec) {
switch c := v.Type.(type) {
case *ast.StructType:
declarationType = "class"
- handleGenDecStructType(p, typeName, c)
+ handleGenDecStructType(p, typeName, c, v.TypeParams)
case *ast.InterfaceType:
declarationType = "interface"
handleGenDecInterfaceType(p, typeName, c)
@@ -379,7 +385,6 @@ func (p *ClassParser) processSpec(spec ast.Spec) {
p.allRenamedStructs[pack[0]][renamedClass] = pack[1]
}
}
- return
}
// If this element is an array or a pointer, this function will return the type that is closer to these
@@ -465,7 +470,7 @@ func (p *ClassParser) renderStructures(pack string, structures map[string]*Struc
str.WriteLineWithDepth(2, aliasComplexNameComment)
str.WriteLineWithDepth(1, "}")
}
- str.WriteLineWithDepth(0, fmt.Sprintf(`}`))
+ str.WriteLineWithDepth(0, `}`)
if p.renderingOptions.Compositions {
str.WriteLineWithDepth(0, composition.String())
}
@@ -479,7 +484,6 @@ func (p *ClassParser) renderStructures(pack string, structures map[string]*Struc
}
func (p *ClassParser) renderAliases(str *LineStringBuilder) {
-
aliasString := ""
if p.renderingOptions.ConnectionLabels {
aliasString = aliasOf
@@ -505,7 +509,6 @@ func (p *ClassParser) renderAliases(str *LineStringBuilder) {
}
func (p *ClassParser) renderStructure(structure *Struct, pack string, name string, str *LineStringBuilder, composition *LineStringBuilder, extends *LineStringBuilder, aggregations *LineStringBuilder) {
-
privateFields := &LineStringBuilder{}
publicFields := &LineStringBuilder{}
privateMethods := &LineStringBuilder{}
@@ -518,9 +521,24 @@ func (p *ClassParser) renderStructure(structure *Struct, pack string, name strin
case "alias":
sType = "<< (T, #FF7700) >> "
renderStructureType = "class"
+ }
+ types := ""
+ if structure.Generics.exists() {
+ types = "<"
+ for t := range structure.Generics.Types {
+ types += fmt.Sprintf("%s, ", t)
+ }
+ types = strings.TrimSuffix(types, ", ")
+ types += " constrains "
+ for _, n := range structure.Generics.Names {
+ types += fmt.Sprintf("%s, ", n)
+ }
+ types = strings.TrimSuffix(types, ", ")
+ types += ">"
}
- str.WriteLineWithDepth(1, fmt.Sprintf(`%s %s %s {`, renderStructureType, name, sType))
+
+ str.WriteLineWithDepth(1, fmt.Sprintf(`%s %s%s %s {`, renderStructureType, name, types, sType))
p.renderStructFields(structure, privateFields, publicFields)
p.renderStructMethods(structure, privateMethods, publicMethods)
p.renderCompositions(structure, name, composition)
@@ -538,7 +556,7 @@ func (p *ClassParser) renderStructure(structure *Struct, pack string, name strin
if publicMethods.Len() > 0 {
str.WriteLineWithDepth(0, publicMethods.String())
}
- str.WriteLineWithDepth(1, fmt.Sprintf(`}`))
+ str.WriteLineWithDepth(1, `}`)
}
func (p *ClassParser) renderCompositions(structure *Struct, name string, composition *LineStringBuilder) {
@@ -562,7 +580,6 @@ func (p *ClassParser) renderCompositions(structure *Struct, name string, composi
}
func (p *ClassParser) renderAggregations(structure *Struct, name string, aggregations *LineStringBuilder) {
-
aggregationMap := structure.Aggregations
if p.renderingOptions.AggregatePrivateMembers {
p.updatePrivateAggregations(structure, aggregationMap)
@@ -571,7 +588,6 @@ func (p *ClassParser) renderAggregations(structure *Struct, name string, aggrega
}
func (p *ClassParser) updatePrivateAggregations(structure *Struct, aggregationsMap map[string]struct{}) {
-
for agg := range structure.PrivateAggregations {
aggregationsMap[agg] = struct{}{}
}
@@ -600,13 +616,13 @@ func (p *ClassParser) renderAggregationMap(aggregationMap map[string]struct{}, s
}
func (p *ClassParser) getPackageName(t string, st *Struct) string {
-
packageName := st.PackageName
if isPrimitiveString(t) {
packageName = builtinPackageName
}
return packageName
}
+
func (p *ClassParser) renderExtends(structure *Struct, name string, extends *LineStringBuilder) {
orderedExtends := []string{}
@@ -628,7 +644,6 @@ func (p *ClassParser) renderExtends(structure *Struct, name string, extends *Lin
}
func (p *ClassParser) renderStructMethods(structure *Struct, privateMethods *LineStringBuilder, publicMethods *LineStringBuilder) {
-
for _, method := range structure.Functions {
accessModifier := "+"
if unicode.IsLower(rune(method.Name[0])) {
@@ -685,6 +700,7 @@ func (p *ClassParser) getOrCreateStruct(name string) *Struct {
Functions: make([]*Function, 0),
Fields: make([]*Field, 0),
Type: "",
+ Generics: NewGeneric(),
Composition: make(map[string]struct{}, 0),
Extends: make(map[string]struct{}, 0),
Aggregations: make(map[string]struct{}, 0),
diff --git a/parser/class_parser_test.go b/parser/class_parser_test.go
index 8ca95bc..92b74b5 100644
--- a/parser/class_parser_test.go
+++ b/parser/class_parser_test.go
@@ -2,7 +2,7 @@ package parser
import (
"go/ast"
- "io/ioutil"
+ "os"
"reflect"
"testing"
)
@@ -94,6 +94,7 @@ func TestGetOrCreateStruct(t *testing.T) {
Functions: make([]*Function, 0),
Fields: make([]*Field, 0),
Type: "",
+ Generics: NewGeneric(),
Composition: make(map[string]struct{}, 0),
Extends: make(map[string]struct{}, 0),
Aggregations: make(map[string]struct{}, 0),
@@ -181,7 +182,6 @@ func TestRenderStructFields(t *testing.T) {
}
func TestRenderStructures(t *testing.T) {
-
structMap := map[string]*Struct{
"MainClass": getTestStruct(),
}
@@ -296,6 +296,7 @@ func getTestStruct() *Struct {
ReturnValues: []string{"int"},
},
},
+ Generics: NewGeneric(),
}
}
@@ -563,7 +564,7 @@ func TestRender(t *testing.T) {
})
resultRender := parser.Render()
- result, err := ioutil.ReadFile("../testingsupport/testingsupport.puml")
+ result, err := os.ReadFile("../testingsupport/testingsupport.puml")
if err != nil {
t.Errorf("TestRender: expected no errors reading testing file, got %s", err.Error())
}
@@ -592,7 +593,7 @@ func TestMultipleFolders(t *testing.T) {
}
resultRender := parser.Render()
- result, err := ioutil.ReadFile("../testingsupport/subfolder1-2.puml")
+ result, err := os.ReadFile("../testingsupport/subfolder1-2.puml")
if err != nil {
t.Errorf("TestMultipleFolders: expected no errors reading testing file, got %s", err.Error())
}
diff --git a/parser/field.go b/parser/field.go
index fb25a88..6874538 100644
--- a/parser/field.go
+++ b/parser/field.go
@@ -9,15 +9,15 @@ import (
const packageConstant = "{packageName}"
-//Field can hold the name and type of any field
+// Field can hold the name and type of any field
type Field struct {
Name string
Type string
FullType string
}
-//Returns a string representation of the given expression if it was recognized.
-//Refer to the implementation to see the different string representations.
+// Returns a string representation of the given expression if it was recognized.
+// Refer to the implementation to see the different string representations.
func getFieldType(exp ast.Expr, aliases map[string]string) (string, []string) {
switch v := exp.(type) {
case *ast.Ident:
@@ -40,12 +40,23 @@ func getFieldType(exp ast.Expr, aliases map[string]string) (string, []string) {
return getFuncType(v, aliases)
case *ast.Ellipsis:
return getEllipsis(v, aliases)
+ case *ast.IndexExpr:
+ return getIndexExpr(v, aliases)
+ case *ast.IndexListExpr:
+ return getIndexListExpr(v, aliases)
}
return "", []string{}
}
-func getIdent(v *ast.Ident, aliases map[string]string) (string, []string) {
+func getIndexExpr(v *ast.IndexExpr, aliases map[string]string) (string, []string) {
+ return getFieldType(v.X, aliases)
+}
+
+func getIndexListExpr(v *ast.IndexListExpr, aliases map[string]string) (string, []string) {
+ return getFieldType(v.X, aliases)
+}
+func getIdent(v *ast.Ident, aliases map[string]string) (string, []string) {
if isPrimitive(v) {
return v.Name, []string{}
}
@@ -59,7 +70,6 @@ func getArrayType(v *ast.ArrayType, aliases map[string]string) (string, []string
}
func getSelectorExp(v *ast.SelectorExpr, aliases map[string]string) (string, []string) {
-
packageName := v.X.(*ast.Ident).Name
if realPackageName, ok := aliases[packageName]; ok {
packageName = realPackageName
@@ -69,26 +79,22 @@ func getSelectorExp(v *ast.SelectorExpr, aliases map[string]string) (string, []s
}
func getMapType(v *ast.MapType, aliases map[string]string) (string, []string) {
-
t1, f1 := getFieldType(v.Key, aliases)
t2, f2 := getFieldType(v.Value, aliases)
return fmt.Sprintf("map[%s]%s", t1, t2), append(f1, f2...)
}
func getStarExp(v *ast.StarExpr, aliases map[string]string) (string, []string) {
-
t, f := getFieldType(v.X, aliases)
return fmt.Sprintf("*%s", t), f
}
func getChanType(v *ast.ChanType, aliases map[string]string) (string, []string) {
-
t, f := getFieldType(v.Value, aliases)
return fmt.Sprintf("chan %s", t), f
}
func getStructType(v *ast.StructType, aliases map[string]string) (string, []string) {
-
fieldList := make([]string, 0)
for _, field := range v.Fields.List {
t, _ := getFieldType(field.Type, aliases)
@@ -98,7 +104,6 @@ func getStructType(v *ast.StructType, aliases map[string]string) (string, []stri
}
func getInterfaceType(v *ast.InterfaceType, aliases map[string]string) (string, []string) {
-
methods := make([]string, 0)
for _, field := range v.Methods.List {
methodName := ""
@@ -112,7 +117,6 @@ func getInterfaceType(v *ast.InterfaceType, aliases map[string]string) (string,
}
func getFuncType(v *ast.FuncType, aliases map[string]string) (string, []string) {
-
function := getFunction(v, "", aliases, "")
params := make([]string, 0)
for _, pa := range function.Parameters {
@@ -120,9 +124,7 @@ func getFuncType(v *ast.FuncType, aliases map[string]string) (string, []string)
}
returns := ""
returnList := make([]string, 0)
- for _, re := range function.ReturnValues {
- returnList = append(returnList, re)
- }
+ returnList = append(returnList, function.ReturnValues...)
if len(returnList) > 1 {
returns = fmt.Sprintf("(%s)", strings.Join(returnList, ", "))
} else {
diff --git a/parser/struct.go b/parser/struct.go
index 1383775..c4ac29e 100644
--- a/parser/struct.go
+++ b/parser/struct.go
@@ -5,13 +5,14 @@ import (
"unicode"
)
-//Struct represent a struct in golang, it can be of Type "class" or "interface" and can be associated
-//with other structs via Composition and Extends
+// Struct represent a struct in golang, it can be of Type "class" or "interface" and can be associated
+// with other structs via Composition and Extends
type Struct struct {
PackageName string
Functions []*Function
Fields []*Field
Type string
+ Generics *Generic
Composition map[string]struct{}
Extends map[string]struct{}
Aggregations map[string]struct{}
@@ -38,9 +39,9 @@ func (st *Struct) ImplementsInterface(inter *Struct) bool {
return true
}
-//AddToComposition adds the composition relation to the structure. We want to make sure that *ExampleStruct
-//gets added as ExampleStruct so that we can properly build the relation later to the
-//class identifier
+// AddToComposition adds the composition relation to the structure. We want to make sure that *ExampleStruct
+// gets added as ExampleStruct so that we can properly build the relation later to the
+// class identifier
func (st *Struct) AddToComposition(fType string) {
if len(fType) == 0 {
return
@@ -51,9 +52,9 @@ func (st *Struct) AddToComposition(fType string) {
st.Composition[fType] = struct{}{}
}
-//AddToExtends Adds an extends relationship to this struct. We want to make sure that *ExampleStruct
-//gets added as ExampleStruct so that we can properly build the relation later to the
-//class identifier
+// AddToExtends Adds an extends relationship to this struct. We want to make sure that *ExampleStruct
+// gets added as ExampleStruct so that we can properly build the relation later to the
+// class identifier
func (st *Struct) AddToExtends(fType string) {
if len(fType) == 0 {
return
@@ -64,18 +65,18 @@ func (st *Struct) AddToExtends(fType string) {
st.Extends[fType] = struct{}{}
}
-//AddToAggregation adds an aggregation type to the list of aggregations
+// AddToAggregation adds an aggregation type to the list of aggregations
func (st *Struct) AddToAggregation(fType string) {
st.Aggregations[fType] = struct{}{}
}
-//addToPrivateAggregation adds an aggregation type to the list of aggregations for private members
+// addToPrivateAggregation adds an aggregation type to the list of aggregations for private members
func (st *Struct) addToPrivateAggregation(fType string) {
st.PrivateAggregations[fType] = struct{}{}
}
-//AddField adds a field into this structure. It parses the ast.Field and extract all
-//needed information
+// AddField adds a field into this structure. It parses the ast.Field and extract all
+// needed information
func (st *Struct) AddField(field *ast.Field, aliases map[string]string) {
theType, fundamentalTypes := getFieldType(field.Type, aliases)
theType = replacePackageConstant(theType, "")
@@ -103,7 +104,11 @@ func (st *Struct) AddField(field *ast.Field, aliases map[string]string) {
}
}
-//AddMethod Parse the Field and if it is an ast.FuncType, then add the methods into the structure
+func (st *Struct) AddTypeParam(field *ast.Field) {
+ st.Generics.getNames(field).getTypes(field)
+}
+
+// AddMethod Parse the Field and if it is an ast.FuncType, then add the methods into the structure
func (st *Struct) AddMethod(method *ast.Field, aliases map[string]string) {
f, ok := method.Type.(*ast.FuncType)
if !ok {
diff --git a/parser/struct_test.go b/parser/struct_test.go
index 9ae7013..c8e9d58 100644
--- a/parser/struct_test.go
+++ b/parser/struct_test.go
@@ -326,7 +326,6 @@ func TestAddToExtension(t *testing.T) {
}
func arrayContains(a map[string]struct{}, text string) bool {
-
found := false
for v := range a {
if v == text {
diff --git a/parser/type.go b/parser/type.go
new file mode 100644
index 0000000..32e3158
--- /dev/null
+++ b/parser/type.go
@@ -0,0 +1,78 @@
+package parser
+
+import (
+ "go/ast"
+)
+
+type Generic struct {
+ Names []string
+ Types map[string]struct{} // Keep a set to cover the case when there are multiple type names using the same type.
+}
+
+func NewGeneric() *Generic {
+ return &Generic{Types: make(map[string]struct{})}
+}
+
+func (g *Generic) exists() bool {
+ return len(g.Names) != 0 && len(g.Types) != 0
+}
+
+func (g *Generic) getNames(field *ast.Field) *Generic {
+ for _, name := range field.Names {
+ g.Names = append(g.Names, name.String())
+ }
+
+ return g
+}
+
+func (g *Generic) getTypes(field *ast.Field) *Generic {
+ switch f := field.Type.(type) {
+ case *ast.Ident:
+ g.Types[f.Name] = struct{}{}
+ case *ast.BinaryExpr:
+ switch x := f.X.(type) {
+ case *ast.Ident:
+ g.Types[x.Name] = struct{}{}
+ }
+
+ switch y := f.Y.(type) {
+ case *ast.Ident:
+ g.Types[y.Name] = struct{}{}
+ }
+
+ // The below, while ugly, handles scenarios where we have N binary expressions.
+ // An example of this could be
+ // type foo[T string | bool | int | int16 | float64] struct{}
+ switch f.X.(type) {
+ case *ast.BinaryExpr:
+ var x ast.Expr = f.X.(*ast.BinaryExpr)
+ var process = true
+ for process {
+ switch xt := x.(type) {
+ case *ast.Ident:
+ g.Types[xt.Name] = struct{}{}
+ }
+
+ switch yt := x.(type) {
+ case *ast.Ident:
+ g.Types[yt.Name] = struct{}{}
+ case *ast.BinaryExpr:
+ switch ytt := yt.Y.(type) {
+ case *ast.Ident:
+ g.Types[ytt.Name] = struct{}{}
+ }
+ }
+
+ newX, safe := x.(*ast.BinaryExpr)
+ process = safe
+ if safe {
+ x = newX.X
+ }
+ }
+ }
+ case *ast.InterfaceType:
+ g.Types["interface"] = struct{}{}
+ }
+
+ return g
+}
diff --git a/parser/type_test.go b/parser/type_test.go
new file mode 100644
index 0000000..bb093de
--- /dev/null
+++ b/parser/type_test.go
@@ -0,0 +1,82 @@
+package parser
+
+import (
+ "go/ast"
+ "reflect"
+ "testing"
+)
+
+func TestNewGeneric(t *testing.T) {
+ g := NewGeneric()
+ if g == nil {
+ t.Fatal("Returned value should not be nil")
+ }
+}
+
+func TestExists(t *testing.T) {
+ g := NewGeneric()
+ if g.exists() {
+ t.Fatal("Should not exist at this point")
+ }
+
+ g.getNames(&ast.Field{
+ Names: []*ast.Ident{{Name: "test"}},
+ }).getTypes(&ast.Field{
+ Type: new(ast.Ident),
+ })
+
+ if !g.exists() {
+ t.Fatal("Should exist at this point")
+ }
+}
+
+func TestGetTypes(t *testing.T) {
+ table := []struct {
+ names []*ast.Ident
+ typ func() ast.Expr
+ expectedNames []string
+ expectedTypes map[string]struct{}
+ }{{
+ names: []*ast.Ident{{Name: "any"}},
+ typ: func() ast.Expr {
+ e := new(ast.Ident)
+ e.Name = "any"
+ return e
+ },
+ expectedNames: []string{"any"},
+ expectedTypes: map[string]struct{}{"any": {}},
+ }, {
+ names: []*ast.Ident{{Name: "int"}},
+ typ: func() ast.Expr {
+ e := new(ast.BinaryExpr)
+ e.X = new(ast.Ident)
+ e.X.(*ast.Ident).Name = "int"
+ e.Y = new(ast.Ident)
+ e.Y.(*ast.Ident).Name = "bool"
+ return e
+ },
+ expectedNames: []string{"int"},
+ expectedTypes: map[string]struct{}{"int": {}, "bool": {}},
+ }, {
+ names: []*ast.Ident{{Name: "interface"}},
+ typ: func() ast.Expr {
+ e := new(ast.InterfaceType)
+ return e
+ },
+ expectedNames: []string{"interface"},
+ expectedTypes: map[string]struct{}{"interface": {}},
+ }}
+
+ for _, entry := range table {
+ g := NewGeneric()
+ field := &ast.Field{Names: entry.names, Type: entry.typ()}
+ g.getNames(field).getTypes(field)
+ if !reflect.DeepEqual(g.Names, entry.expectedNames) {
+ t.Errorf("Mismatched names: %v %v", g.Names, entry.expectedNames)
+ }
+
+ if !reflect.DeepEqual(g.Types, entry.expectedTypes) {
+ t.Errorf("Mismatched types: %v %v", g.Types, entry.expectedTypes)
+ }
+ }
+}
diff --git a/testingsupport/generics/generics.go b/testingsupport/generics/generics.go
new file mode 100644
index 0000000..afac0b6
--- /dev/null
+++ b/testingsupport/generics/generics.go
@@ -0,0 +1,25 @@
+package generics
+
+type SingleAny[T any] struct{}
+type SingleString[T string] struct{}
+type SingleFloat32[T float32] struct{}
+type SingleFloat64[T float64] struct{}
+type SingleInt[T int] struct{}
+type SingleInt16[T int16] struct{}
+type SingleInt32[T int32] struct{}
+type SingleInt64[T int64] struct{}
+type SingleBool[T bool] struct{}
+
+type OrAny[T any | any] struct{}
+type OrMixed[T string | bool] struct{}
+type ManyOrAny[T any | any | any | any] struct{}
+type ManyOrMixed[T string | bool | int | int16] struct{}
+
+type MultipleAny[T any, K any] struct{}
+type MultipleAnyOneType[T, K any] struct{}
+type MultipleString[T string, K string] struct{}
+type MultipleStringOneType[T, K string] struct{}
+
+type AnonIface[T interface{}] struct{}
+type named interface{}
+type NamedIface[T named] struct{}