Skip to content

Commit fd46275

Browse files
committed
Support generic struct types
1 parent c585660 commit fd46275

File tree

9 files changed

+246
-53
lines changed

9 files changed

+246
-53
lines changed

go.mod

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
module github.com/jfeliu007/goplantuml
22

3-
go 1.17
3+
go 1.18
44

5-
require (
6-
github.com/spf13/afero v1.8.2
7-
golang.org/x/text v0.3.7 // indirect
8-
)
5+
require github.com/spf13/afero v1.8.2
6+
7+
require golang.org/x/text v0.3.7 // indirect

parser/class_parser.go

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ call the Render() function and this will return a string with the class diagram.
1010
1111
See github.com/jfeliu007/goplantuml/cmd/goplantuml/main.go for a command that uses this functions and outputs the text to
1212
the console.
13-
1413
*/
1514
package parser
1615

@@ -267,7 +266,6 @@ func (p *ClassParser) parseFileDeclarations(node ast.Decl) {
267266
}
268267

269268
func (p *ClassParser) handleFuncDecl(decl *ast.FuncDecl) {
270-
271269
if decl.Recv != nil {
272270
if decl.Recv.List == nil {
273271
return
@@ -296,24 +294,30 @@ func (p *ClassParser) handleFuncDecl(decl *ast.FuncDecl) {
296294
}
297295
}
298296

299-
func handleGenDecStructType(p *ClassParser, typeName string, c *ast.StructType) {
297+
func handleGenDecStructType(p *ClassParser, typeName string, c *ast.StructType, typeParams *ast.FieldList) {
300298
for _, f := range c.Fields.List {
301299
p.getOrCreateStruct(typeName).AddField(f, p.allImports)
302300
}
301+
302+
if typeParams == nil {
303+
return
304+
}
305+
306+
for _, tp := range typeParams.List {
307+
p.getOrCreateStruct(typeName).AddTypeParam(tp)
308+
}
303309
}
304310

305311
func handleGenDecInterfaceType(p *ClassParser, typeName string, c *ast.InterfaceType) {
306312
for _, f := range c.Methods.List {
307313
switch t := f.Type.(type) {
308314
case *ast.FuncType:
309315
p.getOrCreateStruct(typeName).AddMethod(f, p.allImports)
310-
break
311316
case *ast.Ident:
312317
f, _ := getFieldType(t, p.allImports)
313318
st := p.getOrCreateStruct(typeName)
314319
f = replacePackageConstant(f, st.PackageName)
315320
st.AddToComposition(f)
316-
break
317321
}
318322
}
319323
}
@@ -338,7 +342,7 @@ func (p *ClassParser) processSpec(spec ast.Spec) {
338342
switch c := v.Type.(type) {
339343
case *ast.StructType:
340344
declarationType = "class"
341-
handleGenDecStructType(p, typeName, c)
345+
handleGenDecStructType(p, typeName, c, v.TypeParams)
342346
case *ast.InterfaceType:
343347
declarationType = "interface"
344348
handleGenDecInterfaceType(p, typeName, c)
@@ -379,7 +383,6 @@ func (p *ClassParser) processSpec(spec ast.Spec) {
379383
p.allRenamedStructs[pack[0]][renamedClass] = pack[1]
380384
}
381385
}
382-
return
383386
}
384387

385388
// If this element is an array or a pointer, this function will return the type that is closer to these
@@ -465,7 +468,7 @@ func (p *ClassParser) renderStructures(pack string, structures map[string]*Struc
465468
str.WriteLineWithDepth(2, aliasComplexNameComment)
466469
str.WriteLineWithDepth(1, "}")
467470
}
468-
str.WriteLineWithDepth(0, fmt.Sprintf(`}`))
471+
str.WriteLineWithDepth(0, `}`)
469472
if p.renderingOptions.Compositions {
470473
str.WriteLineWithDepth(0, composition.String())
471474
}
@@ -479,7 +482,6 @@ func (p *ClassParser) renderStructures(pack string, structures map[string]*Struc
479482
}
480483

481484
func (p *ClassParser) renderAliases(str *LineStringBuilder) {
482-
483485
aliasString := ""
484486
if p.renderingOptions.ConnectionLabels {
485487
aliasString = aliasOf
@@ -505,7 +507,6 @@ func (p *ClassParser) renderAliases(str *LineStringBuilder) {
505507
}
506508

507509
func (p *ClassParser) renderStructure(structure *Struct, pack string, name string, str *LineStringBuilder, composition *LineStringBuilder, extends *LineStringBuilder, aggregations *LineStringBuilder) {
508-
509510
privateFields := &LineStringBuilder{}
510511
publicFields := &LineStringBuilder{}
511512
privateMethods := &LineStringBuilder{}
@@ -518,9 +519,24 @@ func (p *ClassParser) renderStructure(structure *Struct, pack string, name strin
518519
case "alias":
519520
sType = "<< (T, #FF7700) >> "
520521
renderStructureType = "class"
522+
}
521523

524+
types := ""
525+
if structure.Generics.exists() {
526+
types = "<"
527+
for t := range structure.Generics.Types {
528+
types += fmt.Sprintf("%s, ", t)
529+
}
530+
types = strings.TrimSuffix(types, ", ")
531+
types += " constrains "
532+
for _, n := range structure.Generics.Names {
533+
types += fmt.Sprintf("%s, ", n)
534+
}
535+
types = strings.TrimSuffix(types, ", ")
536+
types += ">"
522537
}
523-
str.WriteLineWithDepth(1, fmt.Sprintf(`%s %s %s {`, renderStructureType, name, sType))
538+
539+
str.WriteLineWithDepth(1, fmt.Sprintf(`%s %s%s %s {`, renderStructureType, name, types, sType))
524540
p.renderStructFields(structure, privateFields, publicFields)
525541
p.renderStructMethods(structure, privateMethods, publicMethods)
526542
p.renderCompositions(structure, name, composition)
@@ -538,7 +554,7 @@ func (p *ClassParser) renderStructure(structure *Struct, pack string, name strin
538554
if publicMethods.Len() > 0 {
539555
str.WriteLineWithDepth(0, publicMethods.String())
540556
}
541-
str.WriteLineWithDepth(1, fmt.Sprintf(`}`))
557+
str.WriteLineWithDepth(1, `}`)
542558
}
543559

544560
func (p *ClassParser) renderCompositions(structure *Struct, name string, composition *LineStringBuilder) {
@@ -562,7 +578,6 @@ func (p *ClassParser) renderCompositions(structure *Struct, name string, composi
562578
}
563579

564580
func (p *ClassParser) renderAggregations(structure *Struct, name string, aggregations *LineStringBuilder) {
565-
566581
aggregationMap := structure.Aggregations
567582
if p.renderingOptions.AggregatePrivateMembers {
568583
p.updatePrivateAggregations(structure, aggregationMap)
@@ -571,7 +586,6 @@ func (p *ClassParser) renderAggregations(structure *Struct, name string, aggrega
571586
}
572587

573588
func (p *ClassParser) updatePrivateAggregations(structure *Struct, aggregationsMap map[string]struct{}) {
574-
575589
for agg := range structure.PrivateAggregations {
576590
aggregationsMap[agg] = struct{}{}
577591
}
@@ -600,13 +614,13 @@ func (p *ClassParser) renderAggregationMap(aggregationMap map[string]struct{}, s
600614
}
601615

602616
func (p *ClassParser) getPackageName(t string, st *Struct) string {
603-
604617
packageName := st.PackageName
605618
if isPrimitiveString(t) {
606619
packageName = builtinPackageName
607620
}
608621
return packageName
609622
}
623+
610624
func (p *ClassParser) renderExtends(structure *Struct, name string, extends *LineStringBuilder) {
611625

612626
orderedExtends := []string{}
@@ -628,7 +642,6 @@ func (p *ClassParser) renderExtends(structure *Struct, name string, extends *Lin
628642
}
629643

630644
func (p *ClassParser) renderStructMethods(structure *Struct, privateMethods *LineStringBuilder, publicMethods *LineStringBuilder) {
631-
632645
for _, method := range structure.Functions {
633646
accessModifier := "+"
634647
if unicode.IsLower(rune(method.Name[0])) {
@@ -685,6 +698,7 @@ func (p *ClassParser) getOrCreateStruct(name string) *Struct {
685698
Functions: make([]*Function, 0),
686699
Fields: make([]*Field, 0),
687700
Type: "",
701+
Generics: NewGeneric(),
688702
Composition: make(map[string]struct{}, 0),
689703
Extends: make(map[string]struct{}, 0),
690704
Aggregations: make(map[string]struct{}, 0),

parser/class_parser_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package parser
22

33
import (
44
"go/ast"
5-
"io/ioutil"
5+
"os"
66
"reflect"
77
"testing"
88
)
@@ -94,6 +94,7 @@ func TestGetOrCreateStruct(t *testing.T) {
9494
Functions: make([]*Function, 0),
9595
Fields: make([]*Field, 0),
9696
Type: "",
97+
Generics: NewGeneric(),
9798
Composition: make(map[string]struct{}, 0),
9899
Extends: make(map[string]struct{}, 0),
99100
Aggregations: make(map[string]struct{}, 0),
@@ -181,7 +182,6 @@ func TestRenderStructFields(t *testing.T) {
181182
}
182183

183184
func TestRenderStructures(t *testing.T) {
184-
185185
structMap := map[string]*Struct{
186186
"MainClass": getTestStruct(),
187187
}
@@ -296,6 +296,7 @@ func getTestStruct() *Struct {
296296
ReturnValues: []string{"int"},
297297
},
298298
},
299+
Generics: NewGeneric(),
299300
}
300301
}
301302

@@ -563,7 +564,7 @@ func TestRender(t *testing.T) {
563564
})
564565

565566
resultRender := parser.Render()
566-
result, err := ioutil.ReadFile("../testingsupport/testingsupport.puml")
567+
result, err := os.ReadFile("../testingsupport/testingsupport.puml")
567568
if err != nil {
568569
t.Errorf("TestRender: expected no errors reading testing file, got %s", err.Error())
569570
}
@@ -592,7 +593,7 @@ func TestMultipleFolders(t *testing.T) {
592593
}
593594

594595
resultRender := parser.Render()
595-
result, err := ioutil.ReadFile("../testingsupport/subfolder1-2.puml")
596+
result, err := os.ReadFile("../testingsupport/subfolder1-2.puml")
596597
if err != nil {
597598
t.Errorf("TestMultipleFolders: expected no errors reading testing file, got %s", err.Error())
598599
}

parser/field.go

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ import (
99

1010
const packageConstant = "{packageName}"
1111

12-
//Field can hold the name and type of any field
12+
// Field can hold the name and type of any field
1313
type Field struct {
1414
Name string
1515
Type string
1616
FullType string
1717
}
1818

19-
//Returns a string representation of the given expression if it was recognized.
20-
//Refer to the implementation to see the different string representations.
19+
// Returns a string representation of the given expression if it was recognized.
20+
// Refer to the implementation to see the different string representations.
2121
func getFieldType(exp ast.Expr, aliases map[string]string) (string, []string) {
2222
switch v := exp.(type) {
2323
case *ast.Ident:
@@ -45,7 +45,6 @@ func getFieldType(exp ast.Expr, aliases map[string]string) (string, []string) {
4545
}
4646

4747
func getIdent(v *ast.Ident, aliases map[string]string) (string, []string) {
48-
4948
if isPrimitive(v) {
5049
return v.Name, []string{}
5150
}
@@ -59,7 +58,6 @@ func getArrayType(v *ast.ArrayType, aliases map[string]string) (string, []string
5958
}
6059

6160
func getSelectorExp(v *ast.SelectorExpr, aliases map[string]string) (string, []string) {
62-
6361
packageName := v.X.(*ast.Ident).Name
6462
if realPackageName, ok := aliases[packageName]; ok {
6563
packageName = realPackageName
@@ -69,26 +67,22 @@ func getSelectorExp(v *ast.SelectorExpr, aliases map[string]string) (string, []s
6967
}
7068

7169
func getMapType(v *ast.MapType, aliases map[string]string) (string, []string) {
72-
7370
t1, f1 := getFieldType(v.Key, aliases)
7471
t2, f2 := getFieldType(v.Value, aliases)
7572
return fmt.Sprintf("<font color=blue>map</font>[%s]%s", t1, t2), append(f1, f2...)
7673
}
7774

7875
func getStarExp(v *ast.StarExpr, aliases map[string]string) (string, []string) {
79-
8076
t, f := getFieldType(v.X, aliases)
8177
return fmt.Sprintf("*%s", t), f
8278
}
8379

8480
func getChanType(v *ast.ChanType, aliases map[string]string) (string, []string) {
85-
8681
t, f := getFieldType(v.Value, aliases)
8782
return fmt.Sprintf("<font color=blue>chan</font> %s", t), f
8883
}
8984

9085
func getStructType(v *ast.StructType, aliases map[string]string) (string, []string) {
91-
9286
fieldList := make([]string, 0)
9387
for _, field := range v.Fields.List {
9488
t, _ := getFieldType(field.Type, aliases)
@@ -98,7 +92,6 @@ func getStructType(v *ast.StructType, aliases map[string]string) (string, []stri
9892
}
9993

10094
func getInterfaceType(v *ast.InterfaceType, aliases map[string]string) (string, []string) {
101-
10295
methods := make([]string, 0)
10396
for _, field := range v.Methods.List {
10497
methodName := ""
@@ -112,17 +105,14 @@ func getInterfaceType(v *ast.InterfaceType, aliases map[string]string) (string,
112105
}
113106

114107
func getFuncType(v *ast.FuncType, aliases map[string]string) (string, []string) {
115-
116108
function := getFunction(v, "", aliases, "")
117109
params := make([]string, 0)
118110
for _, pa := range function.Parameters {
119111
params = append(params, pa.Type)
120112
}
121113
returns := ""
122114
returnList := make([]string, 0)
123-
for _, re := range function.ReturnValues {
124-
returnList = append(returnList, re)
125-
}
115+
returnList = append(returnList, function.ReturnValues...)
126116
if len(returnList) > 1 {
127117
returns = fmt.Sprintf("(%s)", strings.Join(returnList, ", "))
128118
} else {

0 commit comments

Comments
 (0)