From 6c78d72fe3dd43c7cfd923a35e293acc70b0833b Mon Sep 17 00:00:00 2001 From: Chris Gianelloni Date: Tue, 29 Oct 2024 08:18:47 -0400 Subject: [PATCH] fix: guard possible nil panics Signed-off-by: Chris Gianelloni --- cmd/gouroboros/chainsync.go | 10 +++++++++- ledger/allegra/pparams_test.go | 6 +++++- ledger/alonzo/pparams.go | 7 ++++++- ledger/alonzo/pparams_test.go | 6 +++++- ledger/conway/pparams.go | 3 +++ ledger/mary/pparams_test.go | 6 +++++- ledger/shelley/pparams.go | 3 +++ ledger/shelley/pparams_test.go | 6 +++++- ledger/verify_block.go | 5 +++++ protocol/handshake/server.go | 31 ++++++++++++++++++++++++++++++ protocol/localstatequery/client.go | 6 +++++- 11 files changed, 82 insertions(+), 7 deletions(-) diff --git a/cmd/gouroboros/chainsync.go b/cmd/gouroboros/chainsync.go index c5b56107..ccad2db1 100644 --- a/cmd/gouroboros/chainsync.go +++ b/cmd/gouroboros/chainsync.go @@ -170,7 +170,7 @@ func testChainSync(f *globalFlags) { os.Exit(1) } }() - oConn, err = ouroboros.New( + o, err := ouroboros.New( ouroboros.WithConnection(conn), ouroboros.WithNetworkMagic(uint32(f.networkMagic)), ouroboros.WithErrorChan(errorChan), @@ -183,6 +183,11 @@ func testChainSync(f *globalFlags) { fmt.Printf("ERROR: %s\n", err) os.Exit(1) } + if o == nil { + fmt.Println("ERROR: empty connection") + os.Exit(1) + } + oConn = o var point common.Point if chainSyncFlags.tip { @@ -260,6 +265,9 @@ func chainSyncRollForwardHandler( blockSlot := v.SlotNumber() blockHash, _ := hex.DecodeString(v.Hash()) var err error + if oConn == nil { + return fmt.Errorf("empty ouroboros connection, aborting!") + } block, err = oConn.BlockFetch().Client.GetBlock(common.NewPoint(blockSlot, blockHash)) if err != nil { return err diff --git a/ledger/allegra/pparams_test.go b/ledger/allegra/pparams_test.go index e575446c..b23c7fa2 100644 --- a/ledger/allegra/pparams_test.go +++ b/ledger/allegra/pparams_test.go @@ -132,6 +132,10 @@ func TestAllegraUtxorpc(t *testing.T) { result := inputParams.Utxorpc() if !reflect.DeepEqual(result, expectedUtxorpc) { - t.Fatalf("Utxorpc() test failed for Allegra:\nExpected: %#v\nGot: %#v", expectedUtxorpc, result) + t.Fatalf( + "Utxorpc() test failed for Allegra:\nExpected: %#v\nGot: %#v", + expectedUtxorpc, + result, + ) } } diff --git a/ledger/alonzo/pparams.go b/ledger/alonzo/pparams.go index 804fa8dc..d70a1fe9 100644 --- a/ledger/alonzo/pparams.go +++ b/ledger/alonzo/pparams.go @@ -70,6 +70,9 @@ func (p *AlonzoProtocolParameters) Update( } func (p *AlonzoProtocolParameters) UpdateFromGenesis(genesis *AlonzoGenesis) { + if genesis == nil { + return + } // XXX: do we need to convert this? p.AdaPerUtxoByte = genesis.LovelacePerUtxoWord p.MaxValueSize = genesis.MaxValueSize @@ -144,7 +147,9 @@ func (p *AlonzoProtocolParameters) Utxorpc() *cardano.PParams { MaxValueSize: uint64(p.MaxValueSize), CollateralPercentage: uint64(p.CollateralPercentage), MaxCollateralInputs: uint64(p.MaxCollateralInputs), - CostModels: common.ConvertToUtxorpcCardanoCostModels(p.CostModels), + CostModels: common.ConvertToUtxorpcCardanoCostModels( + p.CostModels, + ), Prices: &cardano.ExPrices{ Memory: &cardano.RationalNumber{ Numerator: int32(p.ExecutionCosts.MemPrice.Num().Int64()), diff --git a/ledger/alonzo/pparams_test.go b/ledger/alonzo/pparams_test.go index 4caf5786..ddc8d1c6 100644 --- a/ledger/alonzo/pparams_test.go +++ b/ledger/alonzo/pparams_test.go @@ -294,6 +294,10 @@ func TestAlonzoUtxorpc(t *testing.T) { result := inputParams.Utxorpc() if !reflect.DeepEqual(result, expectedUtxorpc) { - t.Fatalf("Utxorpc() test failed for Alonzo:\nExpected: %#v\nGot: %#v", expectedUtxorpc, result) + t.Fatalf( + "Utxorpc() test failed for Alonzo:\nExpected: %#v\nGot: %#v", + expectedUtxorpc, + result, + ) } } diff --git a/ledger/conway/pparams.go b/ledger/conway/pparams.go index 74338c5c..9cb9b9a9 100644 --- a/ledger/conway/pparams.go +++ b/ledger/conway/pparams.go @@ -211,6 +211,9 @@ func (p *ConwayProtocolParameters) Update( } func (p *ConwayProtocolParameters) UpdateFromGenesis(genesis *ConwayGenesis) { + if genesis == nil { + return + } p.MinCommitteeSize = genesis.MinCommitteeSize p.CommitteeTermLimit = genesis.CommitteeTermLimit p.GovActionValidityPeriod = genesis.GovActionValidityPeriod diff --git a/ledger/mary/pparams_test.go b/ledger/mary/pparams_test.go index aa096824..d4577432 100644 --- a/ledger/mary/pparams_test.go +++ b/ledger/mary/pparams_test.go @@ -145,6 +145,10 @@ func TestMaryUtxorpc(t *testing.T) { result := inputParams.Utxorpc() if !reflect.DeepEqual(result, expectedUtxorpc) { - t.Fatalf("Utxorpc() test failed for Mary:\nExpected: %#v\nGot: %#v", expectedUtxorpc, result) + t.Fatalf( + "Utxorpc() test failed for Mary:\nExpected: %#v\nGot: %#v", + expectedUtxorpc, + result, + ) } } diff --git a/ledger/shelley/pparams.go b/ledger/shelley/pparams.go index 12722f0a..12111ee7 100644 --- a/ledger/shelley/pparams.go +++ b/ledger/shelley/pparams.go @@ -98,6 +98,9 @@ func (p *ShelleyProtocolParameters) Update( } func (p *ShelleyProtocolParameters) UpdateFromGenesis(genesis *ShelleyGenesis) { + if genesis == nil { + return + } genesisParams := genesis.ProtocolParameters p.MinFeeA = genesisParams.MinFeeA p.MinFeeB = genesisParams.MinFeeB diff --git a/ledger/shelley/pparams_test.go b/ledger/shelley/pparams_test.go index aba6c5ea..58cdba2a 100644 --- a/ledger/shelley/pparams_test.go +++ b/ledger/shelley/pparams_test.go @@ -154,6 +154,10 @@ func TestShelleyUtxorpc(t *testing.T) { result := inputParams.Utxorpc() if !reflect.DeepEqual(result, expectedUtxorpc) { - t.Fatalf("Utxorpc() test failed for Shelley:\nExpected: %#v\nGot: %#v", expectedUtxorpc, result) + t.Fatalf( + "Utxorpc() test failed for Shelley:\nExpected: %#v\nGot: %#v", + expectedUtxorpc, + result, + ) } } diff --git a/ledger/verify_block.go b/ledger/verify_block.go index ba1f1813..680b1fcf 100644 --- a/ledger/verify_block.go +++ b/ledger/verify_block.go @@ -52,6 +52,11 @@ func VerifyBlock(block BlockHexCbor) (error, bool, string, uint64, uint64) { headerUnmarshalError.Error(), ), false, "", 0, 0 } + if header == nil { + return fmt.Errorf( + "VerifyBlock: header returned empty", + ), false, "", 0, 0 + } isKesValid, errKes := VerifyKes(header, slotPerKesPeriod) if errKes != nil { return fmt.Errorf( diff --git a/protocol/handshake/server.go b/protocol/handshake/server.go index b745b2a5..d4d363fe 100644 --- a/protocol/handshake/server.go +++ b/protocol/handshake/server.go @@ -129,6 +129,21 @@ func (s *Server) handleProposeVersions(msg protocol.Message) error { // Decode protocol parameters for selected version versionInfo := protocol.GetProtocolVersion(proposedVersion) versionData := s.config.ProtocolVersionMap[proposedVersion] + if versionData == nil { + msgRefuse := NewMsgRefuse( + []any{ + RefuseReasonDecodeError, + proposedVersion, + fmt.Errorf("handshake failed: refused due to empty version data"), + }, + ) + if err := s.SendMessage(msgRefuse); err != nil { + return err + } + return fmt.Errorf( + "handshake failed: refused due to empty version data", + ) + } proposedVersionData, err := versionInfo.NewVersionDataFromCborFunc( msgProposeVersions.VersionMap[proposedVersion], ) @@ -148,6 +163,22 @@ func (s *Server) handleProposeVersions(msg protocol.Message) error { err, ) } + if proposedVersionData == nil { + msgRefuse := NewMsgRefuse( + []any{ + RefuseReasonDecodeError, + proposedVersion, + fmt.Errorf("handshake failed: refused due to empty version map"), + }, + ) + if err := s.SendMessage(msgRefuse); err != nil { + return err + } + return fmt.Errorf( + "handshake failed: refused due to empty version map", + ) + } + // Check network magic if proposedVersionData.NetworkMagic() != versionData.NetworkMagic() { errMsg := fmt.Sprintf( diff --git a/protocol/localstatequery/client.go b/protocol/localstatequery/client.go index 84968378..5efe3791 100644 --- a/protocol/localstatequery/client.go +++ b/protocol/localstatequery/client.go @@ -115,7 +115,11 @@ func (c *Client) Start() { func (c *Client) Acquire(point *common.Point) error { var msg string if point != nil { - msg = fmt.Sprintf("calling Acquire(point: {Slot: %d, Hash: %x})", point.Slot, point.Hash) + msg = fmt.Sprintf( + "calling Acquire(point: {Slot: %d, Hash: %x})", + point.Slot, + point.Hash, + ) } else { msg = "calling Acquire(point: latest)" }