Skip to content

Commit fa5f78e

Browse files
committed
fix: utilize 'msg' parameter in handleRequestNext and pass it to callback
Signed-off-by: Ales Verbic <[email protected]>
1 parent 1cb4e76 commit fa5f78e

File tree

4 files changed

+100
-6
lines changed

4 files changed

+100
-6
lines changed

protocol/chainsync/chainsync.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ type RollBackwardFunc func(CallbackContext, common.Point, Tip) error
219219
type RollForwardFunc func(CallbackContext, uint, interface{}, Tip) error
220220

221221
type FindIntersectFunc func(CallbackContext, []common.Point) (common.Point, Tip, error)
222-
type RequestNextFunc func(CallbackContext) error
222+
type RequestNextFunc func(CallbackContext, *MsgRequestNext) error
223223

224224
// New returns a new ChainSync object
225225
func New(protoOptions protocol.ProtocolOptions, cfg *Config) *ChainSync {

protocol/chainsync/messages_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@ package chainsync
1717
import (
1818
"encoding/hex"
1919
"fmt"
20-
"github.com/blinklabs-io/gouroboros/cbor"
21-
"github.com/blinklabs-io/gouroboros/ledger"
22-
"github.com/blinklabs-io/gouroboros/protocol"
23-
"github.com/blinklabs-io/gouroboros/protocol/common"
2420
"os"
2521
"reflect"
2622
"strings"
2723
"testing"
24+
25+
"github.com/blinklabs-io/gouroboros/cbor"
26+
"github.com/blinklabs-io/gouroboros/ledger"
27+
"github.com/blinklabs-io/gouroboros/protocol"
28+
"github.com/blinklabs-io/gouroboros/protocol/common"
2829
)
2930

3031
type testDefinition struct {

protocol/chainsync/server.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,12 @@ func (s *Server) handleRequestNext(msg protocol.Message) error {
131131
"received chain-sync RequestNext message but no callback function is defined",
132132
)
133133
}
134-
return s.config.RequestNextFunc(s.callbackContext)
134+
msgRequestNext, ok := msg.(*MsgRequestNext)
135+
if !ok {
136+
return fmt.Errorf("expected MsgRequestNext, got %T", msg)
137+
}
138+
// Pass the message to the callback function
139+
return s.config.RequestNextFunc(s.callbackContext, msgRequestNext)
135140
}
136141

137142
func (s *Server) handleFindIntersect(msg protocol.Message) error {

protocol/chainsync/server_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright 2024 Blink Labs Software
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package chainsync
16+
17+
import (
18+
"fmt"
19+
"testing"
20+
21+
"github.com/stretchr/testify/assert"
22+
)
23+
24+
func TestHandleRequestNext_ValidMessage(t *testing.T) {
25+
called := false
26+
var receivedMsg *MsgRequestNext
27+
28+
server := &Server{
29+
config: &Config{
30+
RequestNextFunc: func(ctx CallbackContext, msg *MsgRequestNext) error {
31+
called = true
32+
receivedMsg = msg
33+
// Ensure that the CBOR data is not empty
34+
if len(msg.Cbor()) == 0 {
35+
return fmt.Errorf("expected non-empty CBOR data")
36+
}
37+
return nil
38+
},
39+
},
40+
callbackContext: CallbackContext{},
41+
}
42+
43+
msg := &MsgRequestNext{}
44+
// Fake CBOR data
45+
rawCborData := []byte{0x01, 0x02, 0x03}
46+
msg.SetCbor(rawCborData)
47+
48+
err := server.handleRequestNext(msg)
49+
50+
assert.NoError(t, err, "expected no error")
51+
assert.True(t, called, "expected RequestNextFunc to be called")
52+
assert.Equal(t, msg, receivedMsg, "expected received message to be the same as sent message")
53+
assert.Equal(t, rawCborData, receivedMsg.Cbor(), "expected raw CBOR data to be passed correctly")
54+
}
55+
56+
func TestHandleRequestNext_InvalidMessageType(t *testing.T) {
57+
server := &Server{
58+
config: &Config{
59+
RequestNextFunc: func(ctx CallbackContext, msg *MsgRequestNext) error {
60+
return nil
61+
},
62+
},
63+
callbackContext: CallbackContext{},
64+
}
65+
66+
msg := &MsgFindIntersect{}
67+
err := server.handleRequestNext(msg)
68+
expectedError := fmt.Sprintf("expected MsgRequestNext, got %T", msg)
69+
70+
assert.Error(t, err, "expected an error due to invalid message type")
71+
assert.EqualError(t, err, expectedError)
72+
}
73+
74+
func TestHandleRequestNext_NilCallback(t *testing.T) {
75+
server := &Server{
76+
config: &Config{
77+
RequestNextFunc: nil,
78+
},
79+
callbackContext: CallbackContext{},
80+
}
81+
82+
msg := &MsgRequestNext{}
83+
err := server.handleRequestNext(msg)
84+
expectedError := "received chain-sync RequestNext message but no callback function is defined"
85+
86+
assert.Error(t, err, "expected an error due to nil callback")
87+
assert.EqualError(t, err, expectedError)
88+
}

0 commit comments

Comments
 (0)