@@ -24,6 +24,7 @@ import (
24
24
"sort"
25
25
26
26
"github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
27
+ "golang.org/x/tools/go/packages"
27
28
)
28
29
29
30
// buildContext keeps the data needed for make*Op.
@@ -96,14 +97,20 @@ func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
96
97
// file and assigns unique names of temporary variables.
97
98
type genContext struct {
98
99
inPackage * types.Package
99
- imports map [string ]struct {}
100
+ imports map [string ]genImportPackage
100
101
tempCounter int
101
102
}
102
103
104
+ type genImportPackage struct {
105
+ alias string
106
+ pkg * types.Package
107
+ }
108
+
103
109
func newGenContext (inPackage * types.Package ) * genContext {
104
110
return & genContext {
105
- inPackage : inPackage ,
106
- imports : make (map [string ]struct {}),
111
+ inPackage : inPackage ,
112
+ imports : make (map [string ]genImportPackage ),
113
+ tempCounter : 0 ,
107
114
}
108
115
}
109
116
@@ -117,32 +124,78 @@ func (ctx *genContext) resetTemp() {
117
124
ctx .tempCounter = 0
118
125
}
119
126
120
- func (ctx * genContext ) addImport (path string ) {
121
- if path == ctx .inPackage .Path () {
122
- return // avoid importing the package that we're generating in.
127
+ func (ctx * genContext ) addImportPath (path string ) {
128
+ pkg , err := ctx .loadPackage (path )
129
+ if err != nil {
130
+ panic (fmt .Sprintf ("can't load package %q: %v" , path , err ))
123
131
}
124
- // TODO: renaming?
125
- ctx .imports [path ] = struct {}{}
132
+ ctx .addImport (pkg )
126
133
}
127
134
128
- // importsList returns all packages that need to be imported.
129
- func (ctx * genContext ) importsList () []string {
130
- imp := make ([]string , 0 , len (ctx .imports ))
131
- for k := range ctx .imports {
132
- imp = append (imp , k )
135
+ func (ctx * genContext ) addImport (pkg * types.Package ) string {
136
+ if pkg .Path () == ctx .inPackage .Path () {
137
+ return "" // avoid importing the package that we're generating in
133
138
}
134
- sort .Strings (imp )
135
- return imp
139
+ if p , exists := ctx .imports [pkg .Path ()]; exists {
140
+ return p .alias
141
+ }
142
+ var (
143
+ baseName = pkg .Name ()
144
+ alias = baseName
145
+ counter = 1
146
+ )
147
+ // If the base name conflicts with an existing import, add a numeric suffix.
148
+ for ctx .hasAlias (alias ) {
149
+ alias = fmt .Sprintf ("%s%d" , baseName , counter )
150
+ counter ++
151
+ }
152
+ ctx .imports [pkg .Path ()] = genImportPackage {alias , pkg }
153
+ return alias
154
+ }
155
+
156
+ // hasAlias checks if an alias is already in use
157
+ func (ctx * genContext ) hasAlias (alias string ) bool {
158
+ for _ , p := range ctx .imports {
159
+ if p .alias == alias {
160
+ return true
161
+ }
162
+ }
163
+ return false
136
164
}
137
165
138
- // qualify is the types.Qualifier used for printing types.
166
+ // loadPackage attempts to load package information
167
+ func (ctx * genContext ) loadPackage (path string ) (* types.Package , error ) {
168
+ cfg := & packages.Config {Mode : packages .NeedName }
169
+ pkgs , err := packages .Load (cfg , path )
170
+ if err != nil {
171
+ return nil , err
172
+ }
173
+ if len (pkgs ) == 0 {
174
+ return nil , fmt .Errorf ("no package found for path %s" , path )
175
+ }
176
+ return types .NewPackage (path , pkgs [0 ].Name ), nil
177
+ }
178
+
179
+ // qualify is the types.Qualifier used for printing types
139
180
func (ctx * genContext ) qualify (pkg * types.Package ) string {
140
181
if pkg .Path () == ctx .inPackage .Path () {
141
182
return ""
142
183
}
143
- ctx .addImport (pkg .Path ())
144
- // TODO: renaming?
145
- return pkg .Name ()
184
+ return ctx .addImport (pkg )
185
+ }
186
+
187
+ // importsList returns all packages that need to be imported
188
+ func (ctx * genContext ) importsList () []string {
189
+ imp := make ([]string , 0 , len (ctx .imports ))
190
+ for path , p := range ctx .imports {
191
+ if p .alias == p .pkg .Name () {
192
+ imp = append (imp , fmt .Sprintf ("%q" , path ))
193
+ } else {
194
+ imp = append (imp , fmt .Sprintf ("%s %q" , p .alias , path ))
195
+ }
196
+ }
197
+ sort .Strings (imp )
198
+ return imp
146
199
}
147
200
148
201
type op interface {
@@ -359,7 +412,7 @@ func (op uint256Op) genWrite(ctx *genContext, v string) string {
359
412
}
360
413
361
414
func (op uint256Op ) genDecode (ctx * genContext ) (string , string ) {
362
- ctx .addImport ("github.com/holiman/uint256" )
415
+ ctx .addImportPath ("github.com/holiman/uint256" )
363
416
364
417
var b bytes.Buffer
365
418
resultV := ctx .temp ()
@@ -732,7 +785,7 @@ func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstru
732
785
// generateDecoder generates the DecodeRLP method on 'typ'.
733
786
func generateDecoder (ctx * genContext , typ string , op op ) []byte {
734
787
ctx .resetTemp ()
735
- ctx .addImport (pathOfPackageRLP )
788
+ ctx .addImportPath (pathOfPackageRLP )
736
789
737
790
result , code := op .genDecode (ctx )
738
791
var b bytes.Buffer
@@ -747,8 +800,8 @@ func generateDecoder(ctx *genContext, typ string, op op) []byte {
747
800
// generateEncoder generates the EncodeRLP method on 'typ'.
748
801
func generateEncoder (ctx * genContext , typ string , op op ) []byte {
749
802
ctx .resetTemp ()
750
- ctx .addImport ("io" )
751
- ctx .addImport (pathOfPackageRLP )
803
+ ctx .addImportPath ("io" )
804
+ ctx .addImportPath (pathOfPackageRLP )
752
805
753
806
var b bytes.Buffer
754
807
fmt .Fprintf (& b , "func (obj *%s) EncodeRLP(_w io.Writer) error {\n " , typ )
@@ -783,7 +836,7 @@ func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]b
783
836
var b bytes.Buffer
784
837
fmt .Fprintf (& b , "package %s\n \n " , pkg .Name ())
785
838
for _ , imp := range ctx .importsList () {
786
- fmt .Fprintf (& b , "import %q \n " , imp )
839
+ fmt .Fprintf (& b , "import %s \n " , imp )
787
840
}
788
841
if encoder {
789
842
fmt .Fprintln (& b )
0 commit comments