@@ -24,6 +24,7 @@ import (
2424 "sort"
2525
2626 "github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
27+ "golang.org/x/tools/go/packages"
2728)
2829
2930// buildContext keeps the data needed for make*Op.
@@ -96,14 +97,20 @@ func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
9697// file and assigns unique names of temporary variables.
9798type genContext struct {
9899 inPackage * types.Package
99- imports map [string ]struct {}
100+ imports map [string ]genImportPackage
100101 tempCounter int
101102}
102103
104+ type genImportPackage struct {
105+ alias string
106+ pkg * types.Package
107+ }
108+
103109func newGenContext (inPackage * types.Package ) * genContext {
104110 return & genContext {
105- inPackage : inPackage ,
106- imports : make (map [string ]struct {}),
111+ inPackage : inPackage ,
112+ imports : make (map [string ]genImportPackage ),
113+ tempCounter : 0 ,
107114 }
108115}
109116
@@ -117,32 +124,78 @@ func (ctx *genContext) resetTemp() {
117124 ctx .tempCounter = 0
118125}
119126
120- func (ctx * genContext ) addImport (path string ) {
121- if path == ctx .inPackage .Path () {
122- return // avoid importing the package that we're generating in.
127+ func (ctx * genContext ) addImportPath (path string ) {
128+ pkg , err := ctx .loadPackage (path )
129+ if err != nil {
130+ panic (fmt .Sprintf ("can't load package %q: %v" , path , err ))
123131 }
124- // TODO: renaming?
125- ctx .imports [path ] = struct {}{}
132+ ctx .addImport (pkg )
126133}
127134
128- // importsList returns all packages that need to be imported.
129- func (ctx * genContext ) importsList () []string {
130- imp := make ([]string , 0 , len (ctx .imports ))
131- for k := range ctx .imports {
132- imp = append (imp , k )
135+ func (ctx * genContext ) addImport (pkg * types.Package ) string {
136+ if pkg .Path () == ctx .inPackage .Path () {
137+ return "" // avoid importing the package that we're generating in
133138 }
134- sort .Strings (imp )
135- return imp
139+ if p , exists := ctx .imports [pkg .Path ()]; exists {
140+ return p .alias
141+ }
142+ var (
143+ baseName = pkg .Name ()
144+ alias = baseName
145+ counter = 1
146+ )
147+ // If the base name conflicts with an existing import, add a numeric suffix.
148+ for ctx .hasAlias (alias ) {
149+ alias = fmt .Sprintf ("%s%d" , baseName , counter )
150+ counter ++
151+ }
152+ ctx .imports [pkg .Path ()] = genImportPackage {alias , pkg }
153+ return alias
154+ }
155+
156+ // hasAlias checks if an alias is already in use
157+ func (ctx * genContext ) hasAlias (alias string ) bool {
158+ for _ , p := range ctx .imports {
159+ if p .alias == alias {
160+ return true
161+ }
162+ }
163+ return false
136164}
137165
138- // qualify is the types.Qualifier used for printing types.
166+ // loadPackage attempts to load package information
167+ func (ctx * genContext ) loadPackage (path string ) (* types.Package , error ) {
168+ cfg := & packages.Config {Mode : packages .NeedName }
169+ pkgs , err := packages .Load (cfg , path )
170+ if err != nil {
171+ return nil , err
172+ }
173+ if len (pkgs ) == 0 {
174+ return nil , fmt .Errorf ("no package found for path %s" , path )
175+ }
176+ return types .NewPackage (path , pkgs [0 ].Name ), nil
177+ }
178+
179+ // qualify is the types.Qualifier used for printing types
139180func (ctx * genContext ) qualify (pkg * types.Package ) string {
140181 if pkg .Path () == ctx .inPackage .Path () {
141182 return ""
142183 }
143- ctx .addImport (pkg .Path ())
144- // TODO: renaming?
145- return pkg .Name ()
184+ return ctx .addImport (pkg )
185+ }
186+
187+ // importsList returns all packages that need to be imported
188+ func (ctx * genContext ) importsList () []string {
189+ imp := make ([]string , 0 , len (ctx .imports ))
190+ for path , p := range ctx .imports {
191+ if p .alias == p .pkg .Name () {
192+ imp = append (imp , fmt .Sprintf ("%q" , path ))
193+ } else {
194+ imp = append (imp , fmt .Sprintf ("%s %q" , p .alias , path ))
195+ }
196+ }
197+ sort .Strings (imp )
198+ return imp
146199}
147200
148201type op interface {
@@ -359,7 +412,7 @@ func (op uint256Op) genWrite(ctx *genContext, v string) string {
359412}
360413
361414func (op uint256Op ) genDecode (ctx * genContext ) (string , string ) {
362- ctx .addImport ("github.com/holiman/uint256" )
415+ ctx .addImportPath ("github.com/holiman/uint256" )
363416
364417 var b bytes.Buffer
365418 resultV := ctx .temp ()
@@ -732,7 +785,7 @@ func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstru
732785// generateDecoder generates the DecodeRLP method on 'typ'.
733786func generateDecoder (ctx * genContext , typ string , op op ) []byte {
734787 ctx .resetTemp ()
735- ctx .addImport (pathOfPackageRLP )
788+ ctx .addImportPath (pathOfPackageRLP )
736789
737790 result , code := op .genDecode (ctx )
738791 var b bytes.Buffer
@@ -747,8 +800,8 @@ func generateDecoder(ctx *genContext, typ string, op op) []byte {
747800// generateEncoder generates the EncodeRLP method on 'typ'.
748801func generateEncoder (ctx * genContext , typ string , op op ) []byte {
749802 ctx .resetTemp ()
750- ctx .addImport ("io" )
751- ctx .addImport (pathOfPackageRLP )
803+ ctx .addImportPath ("io" )
804+ ctx .addImportPath (pathOfPackageRLP )
752805
753806 var b bytes.Buffer
754807 fmt .Fprintf (& b , "func (obj *%s) EncodeRLP(_w io.Writer) error {\n " , typ )
@@ -783,7 +836,7 @@ func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]b
783836 var b bytes.Buffer
784837 fmt .Fprintf (& b , "package %s\n \n " , pkg .Name ())
785838 for _ , imp := range ctx .importsList () {
786- fmt .Fprintf (& b , "import %q \n " , imp )
839+ fmt .Fprintf (& b , "import %s \n " , imp )
787840 }
788841 if encoder {
789842 fmt .Fprintln (& b )
0 commit comments