Skip to content

Commit 6dcc746

Browse files
Added more tests
1 parent 9fbef6a commit 6dcc746

File tree

2 files changed

+128
-9
lines changed

2 files changed

+128
-9
lines changed

compiler/astutil/astutil.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -224,14 +224,14 @@ func hasDirective(node ast.Node, directiveAction string) bool {
224224
// HasDirectivePrefix determines if any line in the given file
225225
// has the given directive prefix in it.
226226
func HasDirectivePrefix(file *ast.File, prefix string) bool {
227-
for _, cg := range file.Comments {
228-
for _, c := range cg.List {
229-
if strings.HasPrefix(c.Text, prefix) {
230-
return true
231-
}
227+
foundDirective := false
228+
ast.Inspect(file, func(n ast.Node) bool {
229+
if c, ok := n.(*ast.Comment); ok && strings.HasPrefix(c.Text, prefix) {
230+
foundDirective = true
232231
}
233-
}
234-
return false
232+
return !foundDirective
233+
})
234+
return foundDirective
235235
}
236236

237237
// FindLoopStmt tries to find the loop statement among the AST nodes in the
@@ -333,7 +333,7 @@ func PruneImports(file *ast.File) {
333333
}
334334
}
335335

336-
// Remove "unused imports" for any import which is used.
336+
// Remove from "unused imports" for any import which is used.
337337
ast.Inspect(file, func(n ast.Node) bool {
338338
if sel, ok := n.(*ast.SelectorExpr); ok {
339339
if id, ok := sel.X.(*ast.Ident); ok && id.Obj == nil {
@@ -346,7 +346,7 @@ func PruneImports(file *ast.File) {
346346
return
347347
}
348348

349-
// Remove "unused imports" for any import used for a directive.
349+
// Remove from "unused imports" for any import used for a directive.
350350
directiveImports := map[string]string{
351351
`unsafe`: `//go:linkname `,
352352
`embed`: `//go:embed `,
@@ -409,6 +409,7 @@ func squeeze[E ast.Node, S ~[]E](s S) S {
409409
// FinalizeRemovals fully removes any declaration, specification, imports
410410
// that have been set to nil. This will also remove any unassociated comment
411411
// groups, including the comments from removed code.
412+
// Comments that are floating and tied to a node will be lost.
412413
func FinalizeRemovals(file *ast.File) {
413414
fileChanged := false
414415
for i, decl := range file.Decls {

compiler/astutil/astutil_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,124 @@ func TestSqueezeIdents(t *testing.T) {
589589
}
590590
}
591591

592+
func TestPruneImports(t *testing.T) {
593+
tests := []struct {
594+
name string
595+
src string
596+
want string
597+
}{
598+
{
599+
name: `no imports`,
600+
src: `package testpackage
601+
func foo() {}`,
602+
want: `package testpackage
603+
func foo() {}`,
604+
}, {
605+
name: `keep used imports`,
606+
src: `package testpackage
607+
import "fmt"
608+
func foo() { fmt.Println("foo") }`,
609+
want: `package testpackage
610+
import "fmt"
611+
func foo() { fmt.Println("foo") }`,
612+
}, {
613+
name: `remove imports that are not used`,
614+
src: `package testpackage
615+
import "fmt"
616+
func foo() { }`,
617+
want: `package testpackage
618+
func foo() { }`,
619+
}, {
620+
name: `remove imports that are unused but masked by an object`,
621+
src: `package testpackage
622+
import "fmt"
623+
var fmt = "format"
624+
func foo() string { return fmt }`,
625+
want: `package testpackage
626+
var fmt = "format"
627+
func foo() string { return fmt }`,
628+
}, {
629+
name: `remove imports from empty file`,
630+
src: `package testpackage
631+
import "fmt"
632+
import _ "unsafe"`,
633+
want: `package testpackage`,
634+
}, {
635+
name: `remove imports from empty file except for unsafe when linking`,
636+
src: `package testpackage
637+
import "fmt"
638+
import "embed"
639+
640+
//go:linkname foo runtime.foo
641+
import "unsafe"`,
642+
want: `package testpackage
643+
644+
//go:linkname foo runtime.foo
645+
import _ "unsafe"`,
646+
}, {
647+
name: `keep embed imports when embedding`,
648+
src: `package testpackage
649+
import "fmt"
650+
import "embed"
651+
import "unsafe"
652+
653+
//go:embed "foo.txt"
654+
var foo string`,
655+
want: `package testpackage
656+
import _ "embed"
657+
658+
//go:embed "foo.txt"
659+
var foo string`,
660+
}, {
661+
name: `keep imports that just needed an underscore`,
662+
src: `package testpackage
663+
import "embed"
664+
//go:linkname foo runtime.foo
665+
import "unsafe"
666+
//go:embed "foo.txt"
667+
var foo string`,
668+
want: `package testpackage
669+
import _ "embed"
670+
//go:linkname foo runtime.foo
671+
import _ "unsafe"
672+
//go:embed "foo.txt"
673+
var foo string`,
674+
}, {
675+
name: `keep imports without names`,
676+
src: `package testpackage
677+
import _ "fmt"
678+
import "log"
679+
import . "math"
680+
681+
var foo string`,
682+
want: `package testpackage
683+
import _ "fmt"
684+
685+
import . "math"
686+
687+
var foo string`,
688+
},
689+
}
690+
691+
for _, test := range tests {
692+
t.Run(test.name, func(t *testing.T) {
693+
st := srctesting.New(t)
694+
695+
srcFile := st.Parse(`testSrc.go`, test.src)
696+
PruneImports(srcFile)
697+
got := srctesting.Format(t, st.FileSet, srcFile)
698+
699+
// parse and format the expected result so that formatting matches
700+
wantFile := st.Parse(`testWant.go`, test.want)
701+
want := srctesting.Format(t, st.FileSet, wantFile)
702+
703+
if got != want {
704+
t.Errorf("Unexpected resulting AST after PruneImports:\n\tgot: %q\n\twant: %q", got, want)
705+
}
706+
})
707+
}
708+
}
709+
592710
func TestFinalizeRemovals(t *testing.T) {
593711
tests := []struct {
594712
name string

0 commit comments

Comments
 (0)