Skip to content

Commit ce9dc33

Browse files
authored
fix(pkg/userop): fix gas multiplication (#461)
* fix(pkg/userop): fix gas multiplication * refactor(pkg/userop): extract gas price provider, make mocked one * test(pkg/userop): skip get polygon gas prices test
1 parent 8c4fa67 commit ce9dc33

File tree

5 files changed

+246
-119
lines changed

5 files changed

+246
-119
lines changed

pkg/userop/gas.go

Lines changed: 8 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -2,85 +2,28 @@ package userop
22

33
import (
44
"context"
5-
"encoding/json"
65
"fmt"
7-
"io"
86
"math/big"
9-
"net/http"
107
"strconv"
118

129
"github.com/ethereum/go-ethereum/common/hexutil"
1310
"github.com/shopspring/decimal"
1411
)
1512

16-
func getPolygonGasPrices(chainId *big.Int) (*big.Int, *big.Int, error) {
17-
var resp *http.Response
18-
var err error
13+
func getGasPricesAndApplyMultipliers(ctx context.Context, provider GasPriceProvider, gasConfig GasConfig) (maxFeePerGas, maxPriorityFeePerGas *big.Int, err error) {
14+
logger.Debug("getting gas prices")
1915

20-
if chainId == nil {
21-
return nil, nil, fmt.Errorf("chain ID is nil")
22-
}
23-
24-
switch {
25-
case chainId.Uint64() == 137:
26-
resp, err = http.Get("https://gasstation.polygon.technology/v2")
27-
case chainId.Uint64() == 80002:
28-
resp, err = http.Get("https://gasstation.polygon.technology/amoy")
29-
default:
30-
return nil, nil, fmt.Errorf("unsupported chain ID: %v", chainId)
31-
}
32-
33-
if err != nil {
34-
return nil, nil, fmt.Errorf("error fetching data: %v", err)
35-
}
36-
defer resp.Body.Close()
37-
38-
body, err := io.ReadAll(resp.Body)
16+
maxFeePerGas, maxPriorityFeePerGas, err = provider.GetGasPrices(ctx)
3917
if err != nil {
40-
return nil, nil, fmt.Errorf("error reading response body: %v", err)
41-
}
42-
43-
var gasData struct {
44-
Fast struct {
45-
MaxPriorityFee decimal.Decimal `json:"maxPriorityFee"`
46-
MaxFee decimal.Decimal `json:"maxFee"`
47-
} `json:"fast"`
18+
return nil, nil, fmt.Errorf("failed to get gas prices: %w", err)
4819
}
4920

50-
err = json.Unmarshal(body, &gasData)
51-
if err != nil {
52-
return nil, nil, fmt.Errorf("error unmarshalling JSON: %v", err)
53-
}
21+
logger.Debug("fetched gas price", "maxFeePerGas", maxFeePerGas, "maxPriorityFeePerGas", maxPriorityFeePerGas)
5422

55-
gweiMult := decimal.NewFromInt(1e9)
56-
57-
maxFeePerGas := gasData.Fast.MaxFee.Mul(gweiMult).BigInt()
58-
maxPriorityFeePerGas := gasData.Fast.MaxPriorityFee.Mul(gweiMult).BigInt()
59-
60-
return maxFeePerGas, maxPriorityFeePerGas, nil
61-
}
62-
63-
func getGasPrices(ctx context.Context, provider EthBackend) (*big.Int, *big.Int, error) {
64-
var maxPriorityFeePerGasStr string
65-
if err := provider.RPC().CallContext(ctx, &maxPriorityFeePerGasStr, "eth_maxPriorityFeePerGas"); err != nil {
66-
return nil, nil, err
67-
}
68-
69-
maxPriorityFeePerGas, ok := new(big.Int).SetString(maxPriorityFeePerGasStr, 0)
70-
if !ok {
71-
return nil, nil, fmt.Errorf("failed to parse maxPriorityFeePerGas: %s", maxPriorityFeePerGasStr)
72-
}
73-
logger.Debug("fetched maxPriorityFeePerGas", "maxPriorityFeePerGas", maxPriorityFeePerGas.String())
74-
75-
// Get the latest block to read its base fee
76-
block, err := provider.BlockByNumber(ctx, nil)
77-
if err != nil {
78-
return nil, nil, err
79-
}
80-
blockBaseFee := block.BaseFee()
81-
logger.Debug("fetched block base fee", "baseFee", blockBaseFee.String())
23+
maxFeePerGas = decimal.NewFromBigInt(maxFeePerGas, 0).Mul(gasConfig.MaxFeePerGasMultiplier).BigInt()
24+
maxPriorityFeePerGas = decimal.NewFromBigInt(maxPriorityFeePerGas, 0).Mul(gasConfig.MaxPriorityFeePerGasMultiplier).BigInt()
8225

83-
maxFeePerGas := blockBaseFee.Add(blockBaseFee, maxPriorityFeePerGas)
26+
logger.Debug("calculated gas price", "maxFeePerGas", maxFeePerGas, "maxPriorityFeePerGas", maxPriorityFeePerGas)
8427

8528
return maxFeePerGas, maxPriorityFeePerGas, nil
8629
}

pkg/userop/gas_price_provider.go

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package userop
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"math/big"
9+
"net/http"
10+
11+
"github.com/shopspring/decimal"
12+
)
13+
14+
type GasPriceProvider interface {
15+
GetGasPrices(ctx context.Context) (maxFeePerGas, maxPriorityFeePerGas *big.Int, err error)
16+
}
17+
18+
type EVMGasPriceProvider struct {
19+
provider EthBackend
20+
}
21+
22+
func NewEVMGasPriceProvider(provider EthBackend) *EVMGasPriceProvider {
23+
return &EVMGasPriceProvider{provider: provider}
24+
}
25+
26+
func (p *EVMGasPriceProvider) GetGasPrices(ctx context.Context) (maxFeePerGas, maxPriorityFeePerGas *big.Int, err error) {
27+
var maxPriorityFeePerGasStr string
28+
if err := p.provider.RPC().CallContext(ctx, &maxPriorityFeePerGasStr, "eth_maxPriorityFeePerGas"); err != nil {
29+
return nil, nil, err
30+
}
31+
32+
maxPriorityFeePerGas, ok := new(big.Int).SetString(maxPriorityFeePerGasStr, 0)
33+
if !ok {
34+
return nil, nil, fmt.Errorf("failed to parse maxPriorityFeePerGas: %s", maxPriorityFeePerGasStr)
35+
}
36+
logger.Debug("fetched maxPriorityFeePerGas", "maxPriorityFeePerGas", maxPriorityFeePerGas.String())
37+
38+
// Get the latest block to read its base fee
39+
block, err := p.provider.BlockByNumber(ctx, nil)
40+
if err != nil {
41+
return nil, nil, err
42+
}
43+
blockBaseFee := block.BaseFee()
44+
logger.Debug("fetched block base fee", "baseFee", blockBaseFee.String())
45+
46+
maxFeePerGas = blockBaseFee.Add(blockBaseFee, maxPriorityFeePerGas)
47+
48+
return maxFeePerGas, maxPriorityFeePerGas, nil
49+
}
50+
51+
type PolygonGasPriceProvider struct {
52+
chainId *big.Int
53+
}
54+
55+
func NewPolygonGasPriceProvider(chainId *big.Int) *PolygonGasPriceProvider {
56+
return &PolygonGasPriceProvider{chainId: chainId}
57+
}
58+
59+
func (p *PolygonGasPriceProvider) GetGasPrices(ctx context.Context) (maxFeePerGas, maxPriorityFeePerGas *big.Int, err error) {
60+
var resp *http.Response
61+
62+
if p.chainId == nil {
63+
return nil, nil, fmt.Errorf("chain ID is nil")
64+
}
65+
66+
switch {
67+
case p.chainId.Uint64() == 137:
68+
resp, err = http.Get("https://gasstation.polygon.technology/v2")
69+
case p.chainId.Uint64() == 80002:
70+
resp, err = http.Get("https://gasstation.polygon.technology/amoy")
71+
default:
72+
return nil, nil, fmt.Errorf("unsupported chain ID: %v", p.chainId)
73+
}
74+
75+
if err != nil {
76+
return nil, nil, fmt.Errorf("error fetching data: %v", err)
77+
}
78+
defer resp.Body.Close()
79+
80+
body, err := io.ReadAll(resp.Body)
81+
if err != nil {
82+
return nil, nil, fmt.Errorf("error reading response body: %v", err)
83+
}
84+
85+
var gasData struct {
86+
Fast struct {
87+
MaxPriorityFee decimal.Decimal `json:"maxPriorityFee"`
88+
MaxFee decimal.Decimal `json:"maxFee"`
89+
} `json:"fast"`
90+
}
91+
92+
err = json.Unmarshal(body, &gasData)
93+
if err != nil {
94+
return nil, nil, fmt.Errorf("error unmarshalling JSON: %v", err)
95+
}
96+
97+
gweiMult := decimal.NewFromInt(1e9)
98+
99+
maxFeePerGas = gasData.Fast.MaxFee.Mul(gweiMult).BigInt()
100+
maxPriorityFeePerGas = gasData.Fast.MaxPriorityFee.Mul(gweiMult).BigInt()
101+
102+
return maxFeePerGas, maxPriorityFeePerGas, nil
103+
}
104+
105+
type MockGasPriceProvider struct {
106+
maxFeePerGas *big.Int
107+
maxPriorityFeePerGas *big.Int
108+
}
109+
110+
func NewMockGasPriceProvider(maxFeePerGas, maxPriorityFeePerGas *big.Int) *MockGasPriceProvider {
111+
return &MockGasPriceProvider{maxFeePerGas: maxFeePerGas, maxPriorityFeePerGas: maxPriorityFeePerGas}
112+
}
113+
114+
func (p *MockGasPriceProvider) GetGasPrices(ctx context.Context) (maxFeePerGas, maxPriorityFeePerGas *big.Int, err error) {
115+
return p.maxFeePerGas, p.maxPriorityFeePerGas, nil
116+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package userop
2+
3+
import (
4+
"context"
5+
"math/big"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestGetPolygonGasPrices(t *testing.T) {
12+
t.Skip("should not depend on external services in test environment")
13+
14+
tests := []struct {
15+
name string
16+
chainId *big.Int
17+
wantErr bool
18+
}{
19+
{
20+
name: "should return gas prices for Polygon",
21+
chainId: big.NewInt(137),
22+
wantErr: false,
23+
},
24+
{
25+
name: "should return gas prices for Amoy",
26+
chainId: big.NewInt(80002),
27+
wantErr: false,
28+
},
29+
{
30+
name: "should return error for other chain",
31+
chainId: big.NewInt(42),
32+
wantErr: true,
33+
},
34+
{
35+
name: "should return error if chainId is nil",
36+
chainId: nil,
37+
wantErr: true,
38+
},
39+
}
40+
41+
for _, tt := range tests {
42+
t.Run(tt.name, func(t *testing.T) {
43+
provider := NewPolygonGasPriceProvider(tt.chainId)
44+
45+
ctx := context.Background()
46+
maxFee, maxPriorityFee, err := provider.GetGasPrices(ctx)
47+
if tt.wantErr {
48+
require.Error(t, err)
49+
} else {
50+
require.NoError(t, err)
51+
require.NotNil(t, maxFee)
52+
require.NotNil(t, maxPriorityFee)
53+
}
54+
})
55+
}
56+
}

pkg/userop/gas_test.go

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,71 @@
11
package userop
22

33
import (
4+
"context"
45
"math/big"
56
"testing"
67

8+
"github.com/shopspring/decimal"
79
"github.com/stretchr/testify/require"
810
)
911

10-
func TestGetPolygonGasPrices(t *testing.T) {
11-
t.Run("Should return gas prices for Polygon", func(t *testing.T) {
12-
chainId := big.NewInt(137)
13-
maxFee, maxPriorityFee, err := getPolygonGasPrices(chainId)
14-
require.NoError(t, err)
15-
require.NotNil(t, maxFee)
16-
require.NotNil(t, maxPriorityFee)
17-
// Gas station states that for Polygon Mainnet the maxPriorityFee is at least 30 Gwei
18-
require.True(t, maxPriorityFee.Cmp(big.NewInt(30).Exp(big.NewInt(10), big.NewInt(9), nil)) >= 0)
19-
})
12+
func TestGetGasPricesMiddleware(t *testing.T) {
13+
t.Run("gas multipliers are applied correctly", func(t *testing.T) {
14+
mockedMaxFeePerGas := big.NewInt(400200400200)
15+
mockedMaxPriorityFeePerGas := big.NewInt(42424242)
16+
mockedGasPricesProvider := NewMockGasPriceProvider(mockedMaxFeePerGas, mockedMaxPriorityFeePerGas)
2017

21-
t.Run("Should return gas prices for Amoy", func(t *testing.T) {
22-
chainId := big.NewInt(80002)
23-
maxFee, maxPriorityFee, err := getPolygonGasPrices(chainId)
24-
require.NoError(t, err)
25-
require.NotNil(t, maxFee)
26-
require.NotNil(t, maxPriorityFee)
27-
})
18+
tests := []struct {
19+
name string
20+
gasMultipliers GasConfig
21+
wantMaxFeePerGas *big.Int
22+
wantMaxPriorityFeePerGas *big.Int
23+
}{
24+
{
25+
name: "no multipliers supplied result in 0 gas prices",
26+
gasMultipliers: GasConfig{},
27+
wantMaxFeePerGas: big.NewInt(0),
28+
wantMaxPriorityFeePerGas: big.NewInt(0),
29+
},
30+
{
31+
name: "1 multipliers do not change gas prices",
32+
gasMultipliers: GasConfig{
33+
MaxPriorityFeePerGasMultiplier: decimal.NewFromFloat(1),
34+
MaxFeePerGasMultiplier: decimal.NewFromFloat(1),
35+
},
36+
wantMaxFeePerGas: mockedMaxFeePerGas,
37+
wantMaxPriorityFeePerGas: mockedMaxPriorityFeePerGas,
38+
},
39+
{
40+
name: "1.5 multipliers are applied correctly",
41+
gasMultipliers: GasConfig{
42+
MaxPriorityFeePerGasMultiplier: decimal.NewFromFloat(1.5),
43+
MaxFeePerGasMultiplier: decimal.NewFromFloat(1.5),
44+
},
45+
wantMaxFeePerGas: big.NewInt(600300600300),
46+
wantMaxPriorityFeePerGas: big.NewInt(63636363),
47+
},
48+
{
49+
name: "2.25 multipliers are applied correctly",
50+
gasMultipliers: GasConfig{
51+
MaxPriorityFeePerGasMultiplier: decimal.NewFromFloat(2.25),
52+
MaxFeePerGasMultiplier: decimal.NewFromFloat(2.25),
53+
},
54+
wantMaxFeePerGas: big.NewInt(900450900450),
55+
wantMaxPriorityFeePerGas: big.NewInt(95454544),
56+
},
57+
}
2858

29-
t.Run("Should return error for other chain", func(t *testing.T) {
30-
chainId := big.NewInt(42)
31-
_, _, err := getPolygonGasPrices(chainId)
32-
require.Error(t, err)
33-
})
59+
ctx := context.Background()
60+
61+
for _, tt := range tests {
62+
t.Run(tt.name, func(t *testing.T) {
63+
gotMaxFeePerGas, gotMaxPriorityFeePerGas, err := getGasPricesAndApplyMultipliers(ctx, mockedGasPricesProvider, tt.gasMultipliers)
64+
require.NoError(t, err)
3465

35-
t.Run("Should return error if chainId is nil", func(t *testing.T) {
36-
_, _, err := getPolygonGasPrices(nil)
37-
require.Error(t, err)
66+
require.True(t, gotMaxFeePerGas.Cmp(tt.wantMaxFeePerGas) == 0)
67+
require.True(t, gotMaxPriorityFeePerGas.Cmp(tt.wantMaxPriorityFeePerGas) == 0)
68+
})
69+
}
3870
})
3971
}

0 commit comments

Comments
 (0)