Skip to content

Commit bcda3d1

Browse files
Started adding tests for ConcatenateFiles
1 parent eb95f4d commit bcda3d1

File tree

2 files changed

+100
-6
lines changed

2 files changed

+100
-6
lines changed

compiler/astutil/astutil.go

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -453,15 +453,24 @@ func FinalizeRemovals(file *ast.File) {
453453

454454
file.Imports = squeeze(file.Imports)
455455

456+
rebuildFileComments(file)
457+
}
458+
459+
// rebuildFileComments will rebuild the top level comments
460+
// group for the given file using the comments found in the file.
461+
func rebuildFileComments(file *ast.File) {
456462
file.Comments = nil // clear this first so ast.Inspect doesn't walk it.
457-
remComments := []*ast.CommentGroup{}
463+
comments := []*ast.CommentGroup{}
464+
if file.Doc != nil {
465+
comments = append(comments, file.Doc)
466+
}
458467
ast.Inspect(file, func(n ast.Node) bool {
459468
if cg, ok := n.(*ast.CommentGroup); ok {
460-
remComments = append(remComments, cg)
469+
comments = append(comments, cg)
461470
}
462471
return true
463472
})
464-
file.Comments = remComments
473+
file.Comments = comments
465474
}
466475

467476
// ConcatenateFiles will concatenate the given tailing files onto the
@@ -488,7 +497,11 @@ func ConcatenateFiles(file *ast.File, tails ...*ast.File) error {
488497
if file.Name.Name != tail.Name.Name {
489498
return fmt.Errorf("can not concatenate files with different package names: %q != %q", file.Name.Name, tail.Name.Name)
490499
}
500+
if file.GoVersion != tail.GoVersion {
501+
return fmt.Errorf("can not concatenate files with different Go versions: %q != %q", file.GoVersion, tail.GoVersion)
502+
}
491503

504+
// Concatenate the imports.
492505
for _, imp := range tail.Imports {
493506
path := imp.Path.Value
494507
if oldImp, ok := imports[path]; ok {
@@ -503,15 +516,33 @@ func ConcatenateFiles(file *ast.File, tails ...*ast.File) error {
503516
imports[imp.Path.Value] = imp
504517
}
505518

519+
// Concatenate the declarations.
506520
file.Decls = append(file.Decls, tail.Decls...)
507-
file.Comments = append(file.Comments, tail.Comments...)
521+
522+
// Concatenate the document comments.
523+
if tail.Doc != nil {
524+
if file.Doc == nil {
525+
file.Doc = tail.Doc
526+
} else {
527+
file.Doc.List = append(file.Doc.List, tail.Doc.List...)
528+
}
529+
}
530+
531+
// Concatenate the unresolved identifier and update scope.
532+
// Both of these are deprecated. See Object.
533+
// We just join them to attempt to keep the file in a valid state.
508534
file.Unresolved = append(file.Unresolved, tail.Unresolved...)
509535
for name, obj := range tail.Scope.Objects {
510536
if _, ok := file.Scope.Objects[name]; ok {
511537
return fmt.Errorf("can not concatenate files with duplicate object names: %q", name)
512538
}
513539
file.Scope.Objects[name] = obj
514540
}
541+
542+
// Update the file end to the new end.
543+
file.FileEnd = tail.FileEnd
515544
}
545+
546+
rebuildFileComments(file)
516547
return nil
517548
}

compiler/astutil/astutil_test.go

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -770,13 +770,13 @@ func TestFinalizeRemovals(t *testing.T) {
770770
t.Run(test.name, func(t *testing.T) {
771771
st := srctesting.New(t)
772772

773-
srcFile := st.Parse("testSrc.go", test.src)
773+
srcFile := st.Parse(`testSrc.go`, test.src)
774774
test.perforator(srcFile)
775775
FinalizeRemovals(srcFile)
776776
got := srctesting.Format(t, st.FileSet, srcFile)
777777

778778
// parse and format the expected result so that formatting matches
779-
wantFile := st.Parse("testWant.go", test.want)
779+
wantFile := st.Parse(`testWant.go`, test.want)
780780
want := srctesting.Format(t, st.FileSet, wantFile)
781781

782782
if got != want {
@@ -785,3 +785,66 @@ func TestFinalizeRemovals(t *testing.T) {
785785
})
786786
}
787787
}
788+
789+
func TestConcatenateFiles(t *testing.T) {
790+
tests := []struct {
791+
name string
792+
srcHead string
793+
srcTail string
794+
want string
795+
expErr string
796+
}{
797+
{
798+
name: `add a method`,
799+
srcHead: `package testpackage
800+
// foo is an original method.
801+
func foo() {}`,
802+
srcTail: `package testpackage
803+
// bar is a concatenated method
804+
// from an additional override file.
805+
func bar() {}`,
806+
want: `package testpackage
807+
// foo is an original method.
808+
func foo() {}
809+
// bar is a concatenated method
810+
// from an additional override file.
811+
func bar() {}`,
812+
},
813+
}
814+
815+
for _, test := range tests {
816+
t.Run(test.name, func(t *testing.T) {
817+
st := srctesting.New(t)
818+
if (len(test.want) > 0) == (len(test.expErr) > 0) {
819+
t.Fatal(`One and only one of "want" and "expErr" must be set`)
820+
}
821+
822+
headFile := st.Parse(`testHead.go`, test.srcHead)
823+
tailFile := st.Parse(`testTail.go`, test.srcTail)
824+
err := ConcatenateFiles(headFile, tailFile)
825+
if err != nil {
826+
if len(test.expErr) == 0 {
827+
t.Errorf(`Expected an AST but got an error: %v`, err)
828+
} else if err.Error() != test.expErr {
829+
t.Errorf("Unexpected error:\n\tgot: %q\n\twant: %q", err.Error(), test.expErr)
830+
}
831+
return
832+
}
833+
834+
ast.Print(st.FileSet, headFile) // TODO: REMOVE
835+
836+
got := srctesting.Format(t, st.FileSet, headFile)
837+
if len(test.want) == 0 {
838+
t.Errorf("Expected an error but got AST:\n\tgot: %q\n\twant: %q", got, test.expErr)
839+
return
840+
}
841+
842+
// parse and format the expected result so that formatting matches
843+
wantFile := st.Parse("testWant.go", test.want)
844+
want := srctesting.Format(t, st.FileSet, wantFile)
845+
if got != want {
846+
t.Errorf("Unexpected resulting AST:\n\tgot: %q\n\twant: %q", got, want)
847+
}
848+
})
849+
}
850+
}

0 commit comments

Comments
 (0)