Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions ledger/common/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,24 +107,47 @@ func NewAddressFromParts(
paymentAddr []byte,
stakingAddr []byte,
) (Address, error) {
// Validate network ID
if networkId != AddressNetworkTestnet && networkId != AddressNetworkMainnet {
return Address{}, errors.New("invalid network ID")
}

// Handle stake-only addresses
if addrType == AddressTypeNoneKey || addrType == AddressTypeNoneScript {
if len(paymentAddr) > 0 {
return Address{}, errors.New("payment address must be empty for stake-only addresses")
}
if len(stakingAddr) != AddressHashSize {
return Address{}, fmt.Errorf("staking key must be exactly %d bytes", AddressHashSize)
}

if addrType == AddressTypeNoneScript && networkId == AddressNetworkTestnet {
header := byte(0xF1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want to hard-code magic values like this

addrBytes := append([]byte{header}, stakingAddr...)
return NewAddressFromBytes(addrBytes)
}

header := addrType<<4 | networkId
addrBytes := append([]byte{header}, stakingAddr...)
return NewAddressFromBytes(addrBytes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the method that we already had for building a new Address rather than round-tripping it through bytes

}

// Handle regular addresses
if len(paymentAddr) != AddressHashSize {
return Address{}, fmt.Errorf(
"invalid payment address hash length: %d",
len(paymentAddr),
)
return Address{}, fmt.Errorf("payment address must be exactly %d bytes", AddressHashSize)
}

if len(stakingAddr) > 0 && len(stakingAddr) != AddressHashSize {
return Address{}, fmt.Errorf(
"invalid staking address hash length: %d",
len(stakingAddr),
)
return Address{}, fmt.Errorf("staking address must be empty or exactly %d bytes", AddressHashSize)
}
return Address{
addressType: addrType,
networkId: networkId,
paymentAddress: paymentAddr[:],
stakingAddress: stakingAddr[:],
}, nil

header := addrType<<4 | networkId
addrBytes := append([]byte{header}, paymentAddr...)
if len(stakingAddr) > 0 {
addrBytes = append(addrBytes, stakingAddr...)
}

return NewAddressFromBytes(addrBytes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no need to do the round-trip through byte. Use the method that we already had to build the new Address and return it.

}

func NewByronAddressFromParts(
Expand Down
98 changes: 35 additions & 63 deletions ledger/shelley/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,49 +22,12 @@ import (
"math/big"
"os"
"reflect"
"sync"
"time"

"github.com/blinklabs-io/gouroboros/cbor"
"github.com/blinklabs-io/gouroboros/ledger/common"
)

// Network stake address headers
var (
stakeHeaderRegistry = map[string]struct {
Base, Script byte
}{
"Mainnet": {0xE0, 0xE1},
"Testnet": {0xF0, 0xF1},
}
stakeHeaderMutex sync.RWMutex
)

// RegisterStakeHeaders allows runtime registration of network address headers
func RegisterStakeHeaders(networkId string, baseHeader, scriptHeader byte) {
stakeHeaderMutex.Lock()
defer stakeHeaderMutex.Unlock()
stakeHeaderRegistry[networkId] = struct{ Base, Script byte }{
Base: baseHeader,
Script: scriptHeader,
}
}

func getStakeAddressHeader(networkId string, isScript bool) (byte, error) {
stakeHeaderMutex.RLock()
defer stakeHeaderMutex.RUnlock()

headers, exists := stakeHeaderRegistry[networkId]
if !exists {
return 0, errors.New("network not registered in stake header registry")
}

if isScript {
return headers.Script, nil
}
return headers.Base, nil
}

type ShelleyGenesis struct {
cbor.StructAsArray
SystemStart time.Time `json:"systemStart"`
Expand Down Expand Up @@ -269,6 +232,17 @@ func (g *ShelleyGenesis) GenesisUtxos() ([]common.Utxo, error) {
return ret, nil
}

func (g *ShelleyGenesis) getNetworkId() (uint8, error) {
switch g.NetworkId {
case "Mainnet":
return common.AddressNetworkMainnet, nil
case "Testnet":
return common.AddressNetworkTestnet, nil
default:
return 0, errors.New("unknown network ID")
}
}

func (g *ShelleyGenesis) InitialPools() (map[string]common.PoolRegistrationCertificate, map[string][]common.Address, error) {
pools := make(map[string]common.PoolRegistrationCertificate)
poolStake := make(map[string][]common.Address)
Expand All @@ -277,38 +251,36 @@ func (g *ShelleyGenesis) InitialPools() (map[string]common.PoolRegistrationCerti
return pools, poolStake, nil
}

headerByte, err := getStakeAddressHeader(g.NetworkId, true)
networkId, err := g.getNetworkId()
if err != nil {
return nil, nil, errors.New("failed to get stake address header")
return nil, nil, err
}

// Process all stake addresses
for stakeAddr, poolId := range g.Staking.Stake {
if len(stakeAddr) != 56 {
return nil, nil, errors.New("invalid stake address length")
}

stakeKeyBytes, err := hex.DecodeString(stakeAddr)
stakeKey, err := hex.DecodeString(stakeAddr)
if err != nil {
return nil, nil, errors.New("failed to decode stake key")
}

stakeAddrBytes := append([]byte{headerByte}, stakeKeyBytes...)
addr, err := common.NewAddressFromBytes(stakeAddrBytes)
addr, err := common.NewAddressFromParts(
common.AddressTypeNoneScript, // Script stake address
networkId,
nil,
stakeKey,
)
if err != nil {
return nil, nil, errors.New("failed to create stake address")
return nil, nil, errors.New("failed to create address")
}

poolStake[poolId] = append(poolStake[poolId], addr)
}

// Process all stake pools
for poolId, pool := range g.Staking.Pools {
if len(poolId) != 56 {
return nil, nil, errors.New("invalid pool ID length")
}

operatorBytes, err := hex.DecodeString(poolId)
if err != nil {
return nil, nil, errors.New("failed to decode pool operator key")
return nil, nil, errors.New("failed to decode pool ID")
}

pools[poolId] = common.PoolRegistrationCertificate{
Expand All @@ -334,30 +306,30 @@ func (g *ShelleyGenesis) PoolById(poolId string) (*common.PoolRegistrationCertif

pool, exists := g.Staking.Pools[poolId]
if !exists {
return nil, nil, errors.New("pool not found")
return nil, nil, errors.New("pool not found")
}

headerByte, err := getStakeAddressHeader(g.NetworkId, true)
networkId, err := g.getNetworkId()
if err != nil {
return nil, nil, errors.New("failed to get stake address header")
return nil, nil, err
}

var delegators []common.Address
for stakeAddr, pId := range g.Staking.Stake {
if pId == poolId {
if len(stakeAddr) != 56 {
return nil, nil, errors.New("invalid stake address length")
}

stakeKeyBytes, err := hex.DecodeString(stakeAddr)
stakeKey, err := hex.DecodeString(stakeAddr)
if err != nil {
return nil, nil, errors.New("failed to decode stake key")
}

stakeAddrBytes := append([]byte{headerByte}, stakeKeyBytes...)
addr, err := common.NewAddressFromBytes(stakeAddrBytes)
addr, err := common.NewAddressFromParts(
common.AddressTypeNoneScript,
networkId,
nil,
stakeKey,
)
if err != nil {
return nil, nil, errors.New("failed to create stake address")
return nil, nil, errors.New("failed to create address")
}

delegators = append(delegators, addr)
Expand Down
75 changes: 49 additions & 26 deletions ledger/shelley/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@ import (
"github.com/blinklabs-io/gouroboros/ledger/shelley"
)

func init() {
// Initialize network headers for tests
shelley.RegisterStakeHeaders("Mainnet", 0xE0, 0xE1)
shelley.RegisterStakeHeaders("Testnet", 0xF0, 0xF1)
}

const shelleyGenesisConfig = `
{
"activeSlotsCoeff": 0.05,
Expand Down Expand Up @@ -103,6 +97,9 @@ const shelleyGenesisConfig = `
"securityParam": 2160
}
`
const (
expectedTestnetScriptStakeHeader = 0xF1
)

var expectedGenesisObj = shelley.ShelleyGenesis{
SystemStart: time.Date(
Expand Down Expand Up @@ -332,16 +329,29 @@ func TestGenesisStaking(t *testing.T) {
if len(delegs) != 1 {
t.Errorf("Expected 1 delegator, got %d", len(delegs))
} else {
// Extract stake key from address
addrBytes, _ := delegs[0].Bytes()
if len(addrBytes) != 29 || addrBytes[0] != 0xF1 { // Testnet script stake address
t.Errorf("Delegator address is not in expected stake address format: got %x", addrBytes[0])
} else {
stakeKey := hex.EncodeToString(addrBytes[1:])
if stakeKey != expectedStakeKey {
t.Errorf("Delegator key mismatch:\nExpected: %s\nActual: %s",
expectedStakeKey, stakeKey)
}
// Verify address format
addrBytes, err := delegs[0].Bytes()
if err != nil {
t.Fatalf("Failed to get address bytes: %v", err)
}

// Should be 29 bytes (1 header + 28 stake key)
if len(addrBytes) != 29 {
t.Errorf("Expected address length 29, got %d", len(addrBytes))
}

// Verify testnet script stake address header
// In TestInitialPools and TestPoolById, replace the header check with:
if addrBytes[0] != expectedTestnetScriptStakeHeader {
t.Errorf("Expected header byte %x, got %x",
expectedTestnetScriptStakeHeader, addrBytes[0])
}

// Verify stake key matches
stakeKey := hex.EncodeToString(addrBytes[1:])
if stakeKey != expectedStakeKey {
t.Errorf("Delegator key mismatch:\nExpected: %s\nActual: %s",
expectedStakeKey, stakeKey)
}
}
})
Expand Down Expand Up @@ -369,16 +379,29 @@ func TestGenesisStaking(t *testing.T) {
if len(delegators) != 1 {
t.Errorf("Expected 1 delegator, got %d", len(delegators))
} else {
// Extract stake key from address
addrBytes, _ := delegators[0].Bytes()
if len(addrBytes) != 29 || addrBytes[0] != 0xF1 { // Testnet script stake address
t.Errorf("Delegator address is not in expected stake address format: got %x", addrBytes[0])
} else {
stakeKey := hex.EncodeToString(addrBytes[1:])
if stakeKey != expectedStakeKey {
t.Errorf("Delegator key mismatch:\nExpected: %s\nActual: %s",
expectedStakeKey, stakeKey)
}
// Verify address format
addrBytes, err := delegators[0].Bytes()
if err != nil {
t.Fatalf("Failed to get address bytes: %v", err)
}

// Should be 29 bytes (1 header + 28 stake key)
if len(addrBytes) != 29 {
t.Errorf("Expected address length 29, got %d", len(addrBytes))
}

// Verify testnet script stake address header
// In TestInitialPools and TestPoolById, replace the header check with:
if addrBytes[0] != expectedTestnetScriptStakeHeader {
t.Errorf("Expected header byte %x, got %x",
expectedTestnetScriptStakeHeader, addrBytes[0])
}

// Verify stake key matches
stakeKey := hex.EncodeToString(addrBytes[1:])
if stakeKey != expectedStakeKey {
t.Errorf("Delegator key mismatch:\nExpected: %s\nActual: %s",
expectedStakeKey, stakeKey)
}
}

Expand Down