Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/blockchaincmd/change_weight.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func setWeight(_ *cobra.Command, args []string) error {
return fmt.Errorf("unable to find Validator Manager address")
}
validatorManagerAddress = sc.Networks[network.Name()].ValidatorManagerAddress
validationID, err := validator.GetValidationID(rpcURL, nodeID, validatorManagerAddress)
validationID, err := validator.GetValidationID(rpcURL, common.HexToAddress(validatorManagerAddress), nodeID)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/blockchaincmd/remove_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func removeValidator(_ *cobra.Command, args []string) error {
}
}
validatorManagerAddress = sc.Networks[network.Name()].ValidatorManagerAddress
validationID, err := validatorsdk.GetRegisteredValidator(
validationID, err := validatorsdk.GetValidationID(
rpcURL,
common.HexToAddress(validatorManagerAddress),
nodeID,
Expand Down
2 changes: 1 addition & 1 deletion cmd/validatorcmd/getBalance.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func getNodeValidationID(
return ids.Empty, false, err
}
managerAddress := common.HexToAddress(validatorManagerAddress)
validationID, err = validator.GetRegisteredValidator(rpcURL, managerAddress, nodeID)
validationID, err = validator.GetValidationID(rpcURL, managerAddress, nodeID)
if err != nil {
return ids.Empty, false, err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/validatorcmd/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func list(_ *cobra.Command, args []string) error {
return err
}
balance := uint64(0)
validationID, err := validator.GetRegisteredValidator(rpcURL, managerAddress, nodeID)
validationID, err := validator.GetValidationID(rpcURL, managerAddress, nodeID)
if err != nil {
ux.Logger.RedXToUser("could not get validation ID for node %s due to %s", nodeID, err)
} else {
Expand Down
15 changes: 15 additions & 0 deletions pkg/contract/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,21 @@ func CallToMethod(
return out, nil
}

func GetSmartContractCallResult[T any](methodName string, out []interface{}) (T, error) {
empty := new(T)
if len(out) == 0 {
return *empty, fmt.Errorf("error at %s call: no return value", methodName)
}
if len(out) != 1 {
return *empty, fmt.Errorf("error at %s call: expected 1 return value, got %d", methodName, len(out))
}
received, typeIsOk := out[0].(T)
if !typeIsOk {
return *empty, fmt.Errorf("error at %s call, expected %T, got %T", methodName, *empty, out[0])
}
return received, nil
}

func DeployContract(
rpcURL string,
privateKey string,
Expand Down
14 changes: 2 additions & 12 deletions pkg/contract/ownable.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
// See the file LICENSE for licensing terms.
package contract

import (
_ "embed"
"fmt"

"github.com/ethereum/go-ethereum/common"
)
import "github.com/ethereum/go-ethereum/common"

// GetContractOwner gets owner for https://docs.openzeppelin.com/contracts/2.x/api/ownership#Ownable-owner contracts
func GetContractOwner(
Expand All @@ -22,10 +17,5 @@ func GetContractOwner(
if err != nil {
return common.Address{}, err
}

ownerAddr, ok := out[0].(common.Address)
if !ok {
return common.Address{}, fmt.Errorf("error at owner() call, expected common.Address, got %T", out[0])
}
return ownerAddr, nil
return GetSmartContractCallResult[common.Address]("owner", out)
}
39 changes: 9 additions & 30 deletions pkg/ictt/operate.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,7 @@ func ERC20TokenHomeGetTokenAddress(
if err != nil {
return common.Address{}, err
}
tokenAddress, b := out[0].(common.Address)
if !b {
return common.Address{}, fmt.Errorf("error at token call, expected common.Address, got %T", out[0])
}
return tokenAddress, nil
return contract.GetSmartContractCallResult[common.Address]("token", out)
}

func NativeTokenHomeGetTokenAddress(
Expand All @@ -74,11 +70,7 @@ func NativeTokenHomeGetTokenAddress(
if err != nil {
return common.Address{}, err
}
tokenAddress, b := out[0].(common.Address)
if !b {
return common.Address{}, fmt.Errorf("error at wrappedToken call, expected common.Address, got %T", out[0])
}
return tokenAddress, nil
return contract.GetSmartContractCallResult[common.Address]("wrappedToken", out)
}

func TokenRemoteIsCollateralized(
Expand All @@ -93,11 +85,7 @@ func TokenRemoteIsCollateralized(
if err != nil {
return false, err
}
isCollateralized, b := out[0].(bool)
if !b {
return false, fmt.Errorf("error at isCollateralized call, expected bool, got %T", out[0])
}
return isCollateralized, nil
return contract.GetSmartContractCallResult[bool]("isCollateralized", out)
}

func TokenHomeGetDecimals(
Expand All @@ -112,11 +100,7 @@ func TokenHomeGetDecimals(
if err != nil {
return 0, err
}
decimals, b := out[0].(uint8)
if !b {
return 0, fmt.Errorf("error at tokenDecimals, expected uint8, got %T", out[0])
}
return decimals, nil
return contract.GetSmartContractCallResult[uint8]("tokenDecimals", out)
}

type RegisteredRemote struct {
Expand Down Expand Up @@ -146,6 +130,9 @@ func TokenHomeGetRegisteredRemote(
registeredRemote RegisteredRemote
b bool
)
if len(out) != 4 {
return RegisteredRemote{}, fmt.Errorf("error at registeredRemotes call, expected 4 return values, got %d", len(out))
}
registeredRemote.Registered, b = out[0].(bool)
if !b {
return RegisteredRemote{}, fmt.Errorf("error at registeredRemotes call, expected bool, got %T", out[0])
Expand Down Expand Up @@ -177,11 +164,7 @@ func ERC20TokenRemoteGetTokenHomeAddress(
if err != nil {
return common.Address{}, err
}
tokenHubAddress, b := out[0].(common.Address)
if !b {
return common.Address{}, fmt.Errorf("error at tokenHubAddress call, expected common.Address, got %T", out[0])
}
return tokenHubAddress, nil
return contract.GetSmartContractCallResult[common.Address]("tokenHomeAddress", out)
}

func NativeTokenRemoteGetTotalNativeAssetSupply(
Expand All @@ -196,11 +179,7 @@ func NativeTokenRemoteGetTotalNativeAssetSupply(
if err != nil {
return nil, err
}
supply, b := out[0].(*big.Int)
if !b {
return nil, fmt.Errorf("error at totalNativeAssetSupply, expected *big.Int, got %T", out[0])
}
return supply, nil
return contract.GetSmartContractCallResult[*big.Int]("totalNativeAssetSupply", out)
}

func ERC20TokenHomeSend(
Expand Down
13 changes: 2 additions & 11 deletions pkg/interchain/operate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package interchain

import (
_ "embed"
"fmt"
"math/big"

"github.com/ava-labs/avalanche-cli/pkg/contract"
Expand All @@ -27,11 +26,7 @@ func GetNextMessageID(
if err != nil {
return ids.Empty, err
}
received, b := out[0].([32]byte)
if !b {
return ids.Empty, fmt.Errorf("error at getNextMessageID call, expected ids.ID, got %T", out[0])
}
return received, nil
return contract.GetSmartContractCallResult[[32]byte]("getNextMessageID", out)
}

func MessageReceived(
Expand All @@ -48,11 +43,7 @@ func MessageReceived(
if err != nil {
return false, err
}
received, b := out[0].(bool)
if !b {
return false, fmt.Errorf("error at messageReceived call, expected bool, got %T", out[0])
}
return received, nil
return contract.GetSmartContractCallResult[bool]("messageReceived", out)
}

func SendCrossChainMessage(
Expand Down
7 changes: 1 addition & 6 deletions pkg/precompiles/allowlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package precompiles

import (
_ "embed"
"fmt"
"math/big"

"github.com/ava-labs/avalanche-cli/pkg/contract"
Expand Down Expand Up @@ -109,9 +108,5 @@ func ReadAllowList(
if err != nil {
return nil, err
}
role, b := out[0].(*big.Int)
if !b {
return nil, fmt.Errorf("error at readAllowList, expected *big.Int, got %T", out[0])
}
return role, nil
return contract.GetSmartContractCallResult[*big.Int]("readAllowList", out)
}
10 changes: 1 addition & 9 deletions pkg/precompiles/warp.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package precompiles

import (
_ "embed"
"fmt"

"github.com/ava-labs/avalanche-cli/pkg/contract"
"github.com/ava-labs/avalanchego/ids"
Expand All @@ -21,12 +20,5 @@ func WarpPrecompileGetBlockchainID(
if err != nil {
return ids.Empty, err
}
if len(out) == 0 {
return ids.Empty, fmt.Errorf("error at getBlockchainID call: no return value")
}
received, ok := out[0].([32]byte)
if !ok {
return ids.Empty, fmt.Errorf("error at getBlockchainID call, expected ids.ID, got %T", out[0])
}
return received, nil
return contract.GetSmartContractCallResult[[32]byte]("getBlockchainID", out)
}
7 changes: 1 addition & 6 deletions pkg/validatormanager/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package validatormanager

import (
_ "embed"
"fmt"
"math/big"

"github.com/ava-labs/avalanche-cli/pkg/contract"
Expand Down Expand Up @@ -47,11 +46,7 @@ func GetProxyValidatorManager(
if err != nil {
return common.Address{}, err
}
validatorManagerAddress, b := out[0].(common.Address)
if !b {
return common.Address{}, fmt.Errorf("error obtaining proxy implementation, expected common.Address, got %T", out[0])
}
return validatorManagerAddress, nil
return contract.GetSmartContractCallResult[common.Address]("getProxyImplementation", out)
}

func ProxyHasValidatorManagerSet(
Expand Down
8 changes: 2 additions & 6 deletions pkg/validatormanager/registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func GetRegisterL1ValidatorMessage(
)
if registerSubnetValidatorUnsignedMessage == nil {
if alreadyInitialized {
validationID, err = validator.GetRegisteredValidator(
validationID, err = validator.GetValidationID(
rpcURL,
managerAddress,
nodeID,
Expand Down Expand Up @@ -299,11 +299,7 @@ func PoSWeightToValue(
if err != nil {
return nil, err
}
value, b := out[0].(*big.Int)
if !b {
return nil, fmt.Errorf("error at weightToValue, expected *big.Int, got %T", out[0])
}
return value, nil
return contract.GetSmartContractCallResult[*big.Int]("weightToValue", out)
}

func GetPChainL1ValidatorRegistrationMessage(
Expand Down
2 changes: 1 addition & 1 deletion pkg/validatormanager/removal.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func InitValidatorRemoval(
}
managerAddress := common.HexToAddress(validatorManagerAddressStr)
ownerAddress := common.HexToAddress(ownerAddressStr)
validationID, err := validator.GetRegisteredValidator(
validationID, err := validator.GetValidationID(
rpcURL,
managerAddress,
nodeID,
Expand Down
2 changes: 1 addition & 1 deletion pkg/validatormanager/weight_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func InitValidatorWeightChange(
}
managerAddress := common.HexToAddress(validatorManagerAddressStr)
ownerAddress := common.HexToAddress(ownerAddressStr)
validationID, err := validator.GetRegisteredValidator(
validationID, err := validator.GetValidationID(
rpcURL,
managerAddress,
nodeID,
Expand Down
21 changes: 5 additions & 16 deletions sdk/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
package validator

import (
"fmt"

"github.com/ava-labs/avalanche-cli/pkg/contract"
"github.com/ava-labs/avalanche-cli/sdk/network"
"github.com/ava-labs/avalanche-cli/sdk/utils"
"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/vms/platformvm"
"github.com/ava-labs/avalanchego/vms/platformvm/api"

"github.com/ethereum/go-ethereum/common"
"golang.org/x/exp/maps"
)
Expand Down Expand Up @@ -70,12 +69,9 @@ func GetValidatorInfo(net network.Network, validationID ids.ID) (platformvm.L1Va
return vdrInfo, nil
}

func GetValidationID(rpcURL string, nodeID ids.NodeID, validatorManagerAddressStr string) (ids.ID, error) {
managerAddress := common.HexToAddress(validatorManagerAddressStr)
return GetRegisteredValidator(rpcURL, managerAddress, nodeID)
}

func GetRegisteredValidator(
// Returns the validation ID for the Node ID, as registered at the validator manager
// Will return ids.Empty in case it is not registered
func GetValidationID(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add description to this function?

// Verifies
/**
* @notice Returns a validation ID registered to the given nodeID
* @param nodeID ID of the node associated with the validation ID
*/

rpcURL string,
managerAddress common.Address,
nodeID ids.NodeID,
Expand All @@ -89,14 +85,7 @@ func GetRegisteredValidator(
if err != nil {
return ids.Empty, err
}
if len(out) == 0 {
return ids.Empty, fmt.Errorf("error at registeredValidators call, no value returned")
}
validatorID, typeIsOk := out[0].([32]byte)
if !typeIsOk {
return ids.Empty, fmt.Errorf("error at registeredValidators call, expected [32]byte, got %T", out[0])
}
return validatorID, nil
return contract.GetSmartContractCallResult[[32]byte]("registeredValidators", out)
}

func IsSovereignValidator(
Expand Down
1 change: 1 addition & 0 deletions sdk/validatormanager/validator_manager_pos.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package validatormanager
import (
"github.com/ava-labs/avalanche-cli/pkg/contract"
"github.com/ava-labs/subnet-evm/core/types"

"github.com/ethereum/go-ethereum/common"
)

Expand Down
Loading