Skip to content

Commit 5572f2e

Browse files
VolodymyrBgfjl
andauthored
rlp/rlpgen: implement package renaming support (#31148)
This adds support for importing types from multiple identically-named packages. --------- Co-authored-by: Felix Lange <[email protected]>
1 parent 038ff76 commit 5572f2e

File tree

4 files changed

+173
-25
lines changed

4 files changed

+173
-25
lines changed

rlp/rlpgen/gen.go

Lines changed: 77 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"sort"
2525

2626
"github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
27+
"golang.org/x/tools/go/packages"
2728
)
2829

2930
// buildContext keeps the data needed for make*Op.
@@ -96,14 +97,20 @@ func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
9697
// file and assigns unique names of temporary variables.
9798
type genContext struct {
9899
inPackage *types.Package
99-
imports map[string]struct{}
100+
imports map[string]genImportPackage
100101
tempCounter int
101102
}
102103

104+
type genImportPackage struct {
105+
alias string
106+
pkg *types.Package
107+
}
108+
103109
func newGenContext(inPackage *types.Package) *genContext {
104110
return &genContext{
105-
inPackage: inPackage,
106-
imports: make(map[string]struct{}),
111+
inPackage: inPackage,
112+
imports: make(map[string]genImportPackage),
113+
tempCounter: 0,
107114
}
108115
}
109116

@@ -117,32 +124,78 @@ func (ctx *genContext) resetTemp() {
117124
ctx.tempCounter = 0
118125
}
119126

120-
func (ctx *genContext) addImport(path string) {
121-
if path == ctx.inPackage.Path() {
122-
return // avoid importing the package that we're generating in.
127+
func (ctx *genContext) addImportPath(path string) {
128+
pkg, err := ctx.loadPackage(path)
129+
if err != nil {
130+
panic(fmt.Sprintf("can't load package %q: %v", path, err))
123131
}
124-
// TODO: renaming?
125-
ctx.imports[path] = struct{}{}
132+
ctx.addImport(pkg)
126133
}
127134

128-
// importsList returns all packages that need to be imported.
129-
func (ctx *genContext) importsList() []string {
130-
imp := make([]string, 0, len(ctx.imports))
131-
for k := range ctx.imports {
132-
imp = append(imp, k)
135+
func (ctx *genContext) addImport(pkg *types.Package) string {
136+
if pkg.Path() == ctx.inPackage.Path() {
137+
return "" // avoid importing the package that we're generating in
133138
}
134-
sort.Strings(imp)
135-
return imp
139+
if p, exists := ctx.imports[pkg.Path()]; exists {
140+
return p.alias
141+
}
142+
var (
143+
baseName = pkg.Name()
144+
alias = baseName
145+
counter = 1
146+
)
147+
// If the base name conflicts with an existing import, add a numeric suffix.
148+
for ctx.hasAlias(alias) {
149+
alias = fmt.Sprintf("%s%d", baseName, counter)
150+
counter++
151+
}
152+
ctx.imports[pkg.Path()] = genImportPackage{alias, pkg}
153+
return alias
154+
}
155+
156+
// hasAlias checks if an alias is already in use
157+
func (ctx *genContext) hasAlias(alias string) bool {
158+
for _, p := range ctx.imports {
159+
if p.alias == alias {
160+
return true
161+
}
162+
}
163+
return false
136164
}
137165

138-
// qualify is the types.Qualifier used for printing types.
166+
// loadPackage attempts to load package information
167+
func (ctx *genContext) loadPackage(path string) (*types.Package, error) {
168+
cfg := &packages.Config{Mode: packages.NeedName}
169+
pkgs, err := packages.Load(cfg, path)
170+
if err != nil {
171+
return nil, err
172+
}
173+
if len(pkgs) == 0 {
174+
return nil, fmt.Errorf("no package found for path %s", path)
175+
}
176+
return types.NewPackage(path, pkgs[0].Name), nil
177+
}
178+
179+
// qualify is the types.Qualifier used for printing types
139180
func (ctx *genContext) qualify(pkg *types.Package) string {
140181
if pkg.Path() == ctx.inPackage.Path() {
141182
return ""
142183
}
143-
ctx.addImport(pkg.Path())
144-
// TODO: renaming?
145-
return pkg.Name()
184+
return ctx.addImport(pkg)
185+
}
186+
187+
// importsList returns all packages that need to be imported
188+
func (ctx *genContext) importsList() []string {
189+
imp := make([]string, 0, len(ctx.imports))
190+
for path, p := range ctx.imports {
191+
if p.alias == p.pkg.Name() {
192+
imp = append(imp, fmt.Sprintf("%q", path))
193+
} else {
194+
imp = append(imp, fmt.Sprintf("%s %q", p.alias, path))
195+
}
196+
}
197+
sort.Strings(imp)
198+
return imp
146199
}
147200

148201
type op interface {
@@ -359,7 +412,7 @@ func (op uint256Op) genWrite(ctx *genContext, v string) string {
359412
}
360413

361414
func (op uint256Op) genDecode(ctx *genContext) (string, string) {
362-
ctx.addImport("github.com/holiman/uint256")
415+
ctx.addImportPath("github.com/holiman/uint256")
363416

364417
var b bytes.Buffer
365418
resultV := ctx.temp()
@@ -732,7 +785,7 @@ func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstru
732785
// generateDecoder generates the DecodeRLP method on 'typ'.
733786
func generateDecoder(ctx *genContext, typ string, op op) []byte {
734787
ctx.resetTemp()
735-
ctx.addImport(pathOfPackageRLP)
788+
ctx.addImportPath(pathOfPackageRLP)
736789

737790
result, code := op.genDecode(ctx)
738791
var b bytes.Buffer
@@ -747,8 +800,8 @@ func generateDecoder(ctx *genContext, typ string, op op) []byte {
747800
// generateEncoder generates the EncodeRLP method on 'typ'.
748801
func generateEncoder(ctx *genContext, typ string, op op) []byte {
749802
ctx.resetTemp()
750-
ctx.addImport("io")
751-
ctx.addImport(pathOfPackageRLP)
803+
ctx.addImportPath("io")
804+
ctx.addImportPath(pathOfPackageRLP)
752805

753806
var b bytes.Buffer
754807
fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ)
@@ -783,7 +836,7 @@ func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]b
783836
var b bytes.Buffer
784837
fmt.Fprintf(&b, "package %s\n\n", pkg.Name())
785838
for _, imp := range ctx.importsList() {
786-
fmt.Fprintf(&b, "import %q\n", imp)
839+
fmt.Fprintf(&b, "import %s\n", imp)
787840
}
788841
if encoder {
789842
fmt.Fprintln(&b)

rlp/rlpgen/gen_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func init() {
4747
}
4848
}
4949

50-
var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"}
50+
var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256", "pkgclash"}
5151

5252
func TestOutput(t *testing.T) {
5353
for _, test := range tests {

rlp/rlpgen/testdata/pkgclash.in.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// -*- mode: go -*-
2+
3+
package test
4+
5+
import (
6+
eth1 "github.com/ethereum/go-ethereum/eth"
7+
eth2 "github.com/ethereum/go-ethereum/eth/protocols/eth"
8+
)
9+
10+
type Test struct {
11+
A eth1.MinerAPI
12+
B eth2.GetReceiptsPacket
13+
}

rlp/rlpgen/testdata/pkgclash.out.txt

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package test
2+
3+
import "github.com/ethereum/go-ethereum/common"
4+
import "github.com/ethereum/go-ethereum/eth"
5+
import "github.com/ethereum/go-ethereum/rlp"
6+
import "io"
7+
import eth1 "github.com/ethereum/go-ethereum/eth/protocols/eth"
8+
9+
func (obj *Test) EncodeRLP(_w io.Writer) error {
10+
w := rlp.NewEncoderBuffer(_w)
11+
_tmp0 := w.List()
12+
_tmp1 := w.List()
13+
w.ListEnd(_tmp1)
14+
_tmp2 := w.List()
15+
w.WriteUint64(obj.B.RequestId)
16+
_tmp3 := w.List()
17+
for _, _tmp4 := range obj.B.GetReceiptsRequest {
18+
w.WriteBytes(_tmp4[:])
19+
}
20+
w.ListEnd(_tmp3)
21+
w.ListEnd(_tmp2)
22+
w.ListEnd(_tmp0)
23+
return w.Flush()
24+
}
25+
26+
func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
27+
var _tmp0 Test
28+
{
29+
if _, err := dec.List(); err != nil {
30+
return err
31+
}
32+
// A:
33+
var _tmp1 eth.MinerAPI
34+
{
35+
if _, err := dec.List(); err != nil {
36+
return err
37+
}
38+
if err := dec.ListEnd(); err != nil {
39+
return err
40+
}
41+
}
42+
_tmp0.A = _tmp1
43+
// B:
44+
var _tmp2 eth1.GetReceiptsPacket
45+
{
46+
if _, err := dec.List(); err != nil {
47+
return err
48+
}
49+
// RequestId:
50+
_tmp3, err := dec.Uint64()
51+
if err != nil {
52+
return err
53+
}
54+
_tmp2.RequestId = _tmp3
55+
// GetReceiptsRequest:
56+
var _tmp4 []common.Hash
57+
if _, err := dec.List(); err != nil {
58+
return err
59+
}
60+
for dec.MoreDataInList() {
61+
var _tmp5 common.Hash
62+
if err := dec.ReadBytes(_tmp5[:]); err != nil {
63+
return err
64+
}
65+
_tmp4 = append(_tmp4, _tmp5)
66+
}
67+
if err := dec.ListEnd(); err != nil {
68+
return err
69+
}
70+
_tmp2.GetReceiptsRequest = _tmp4
71+
if err := dec.ListEnd(); err != nil {
72+
return err
73+
}
74+
}
75+
_tmp0.B = _tmp2
76+
if err := dec.ListEnd(); err != nil {
77+
return err
78+
}
79+
}
80+
*obj = _tmp0
81+
return nil
82+
}

0 commit comments

Comments
 (0)