@@ -589,6 +589,124 @@ func TestSqueezeIdents(t *testing.T) {
589589 }
590590}
591591
592+ func TestPruneImports (t * testing.T ) {
593+ tests := []struct {
594+ name string
595+ src string
596+ want string
597+ }{
598+ {
599+ name : `no imports` ,
600+ src : `package testpackage
601+ func foo() {}` ,
602+ want : `package testpackage
603+ func foo() {}` ,
604+ }, {
605+ name : `keep used imports` ,
606+ src : `package testpackage
607+ import "fmt"
608+ func foo() { fmt.Println("foo") }` ,
609+ want : `package testpackage
610+ import "fmt"
611+ func foo() { fmt.Println("foo") }` ,
612+ }, {
613+ name : `remove imports that are not used` ,
614+ src : `package testpackage
615+ import "fmt"
616+ func foo() { }` ,
617+ want : `package testpackage
618+ func foo() { }` ,
619+ }, {
620+ name : `remove imports that are unused but masked by an object` ,
621+ src : `package testpackage
622+ import "fmt"
623+ var fmt = "format"
624+ func foo() string { return fmt }` ,
625+ want : `package testpackage
626+ var fmt = "format"
627+ func foo() string { return fmt }` ,
628+ }, {
629+ name : `remove imports from empty file` ,
630+ src : `package testpackage
631+ import "fmt"
632+ import _ "unsafe"` ,
633+ want : `package testpackage` ,
634+ }, {
635+ name : `remove imports from empty file except for unsafe when linking` ,
636+ src : `package testpackage
637+ import "fmt"
638+ import "embed"
639+
640+ //go:linkname foo runtime.foo
641+ import "unsafe"` ,
642+ want : `package testpackage
643+
644+ //go:linkname foo runtime.foo
645+ import _ "unsafe"` ,
646+ }, {
647+ name : `keep embed imports when embedding` ,
648+ src : `package testpackage
649+ import "fmt"
650+ import "embed"
651+ import "unsafe"
652+
653+ //go:embed "foo.txt"
654+ var foo string` ,
655+ want : `package testpackage
656+ import _ "embed"
657+
658+ //go:embed "foo.txt"
659+ var foo string` ,
660+ }, {
661+ name : `keep imports that just needed an underscore` ,
662+ src : `package testpackage
663+ import "embed"
664+ //go:linkname foo runtime.foo
665+ import "unsafe"
666+ //go:embed "foo.txt"
667+ var foo string` ,
668+ want : `package testpackage
669+ import _ "embed"
670+ //go:linkname foo runtime.foo
671+ import _ "unsafe"
672+ //go:embed "foo.txt"
673+ var foo string` ,
674+ }, {
675+ name : `keep imports without names` ,
676+ src : `package testpackage
677+ import _ "fmt"
678+ import "log"
679+ import . "math"
680+
681+ var foo string` ,
682+ want : `package testpackage
683+ import _ "fmt"
684+
685+ import . "math"
686+
687+ var foo string` ,
688+ },
689+ }
690+
691+ for _ , test := range tests {
692+ t .Run (test .name , func (t * testing.T ) {
693+ st := srctesting .New (t )
694+
695+ srcFile := st .Parse (`testSrc.go` , test .src )
696+ PruneImports (srcFile )
697+ got := srctesting .Format (t , st .FileSet , srcFile )
698+
699+ // parse and format the expected result so that formatting matches
700+ wantFile := st .Parse (`testWant.go` , test .want )
701+ want := srctesting .Format (t , st .FileSet , wantFile )
702+
703+ if got != want {
704+ t .Errorf ("Unexpected resulting AST after PruneImports:\n \t got: %q\n \t want: %q" , got , want )
705+ }
706+ })
707+ }
708+ }
709+
592710func TestFinalizeRemovals (t * testing.T ) {
593711 tests := []struct {
594712 name string
0 commit comments