Skip to content

Commit c111882

Browse files
authored
feat: support for restarting server protocols (#611)
This allows a connected peer to start a protocol, stop it, and later start it again within the same connection Fixes #452
1 parent 99f51e0 commit c111882

File tree

6 files changed

+110
-34
lines changed

6 files changed

+110
-34
lines changed

muxer/muxer.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,26 @@ func (m *Muxer) RegisterProtocol(
178178
return senderChan, receiverChan, m.doneChan
179179
}
180180

181+
func (m *Muxer) UnregisterProtocol(
182+
protocolId uint16,
183+
protocolRole ProtocolRole,
184+
) {
185+
m.protocolReceiversMutex.Lock()
186+
protocolRoles, ok := m.protocolReceivers[protocolId]
187+
if !ok {
188+
return
189+
}
190+
recvChan, ok := protocolRoles[protocolRole]
191+
if !ok {
192+
return
193+
}
194+
// Signal shutdown to protocol
195+
close(recvChan)
196+
// Remove mapping
197+
delete(protocolRoles, protocolRole)
198+
m.protocolReceiversMutex.Unlock()
199+
}
200+
181201
// Send takes a populated Segment and writes it to the connection. A mutex is used to prevent more than
182202
// one protocol from sending at once
183203
func (m *Muxer) Send(msg *Segment) error {

protocol/blockfetch/server.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,30 +25,37 @@ type Server struct {
2525
*protocol.Protocol
2626
config *Config
2727
callbackContext CallbackContext
28+
protoOptions protocol.ProtocolOptions
2829
}
2930

3031
func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
3132
s := &Server{
3233
config: cfg,
34+
// Save this for re-use later
35+
protoOptions: protoOptions,
3336
}
3437
s.callbackContext = CallbackContext{
3538
Server: s,
3639
ConnectionId: protoOptions.ConnectionId,
3740
}
41+
s.initProtocol()
42+
return s
43+
}
44+
45+
func (s *Server) initProtocol() {
3846
protoConfig := protocol.ProtocolConfig{
3947
Name: ProtocolName,
4048
ProtocolId: ProtocolId,
41-
Muxer: protoOptions.Muxer,
42-
ErrorChan: protoOptions.ErrorChan,
43-
Mode: protoOptions.Mode,
49+
Muxer: s.protoOptions.Muxer,
50+
ErrorChan: s.protoOptions.ErrorChan,
51+
Mode: s.protoOptions.Mode,
4452
Role: protocol.ProtocolRoleServer,
4553
MessageHandlerFunc: s.messageHandler,
4654
MessageFromCborFunc: NewMsgFromCbor,
4755
StateMap: StateMap,
4856
InitialState: StateIdle,
4957
}
5058
s.Protocol = protocol.New(protoConfig)
51-
return s
5259
}
5360

5461
func (s *Server) NoBlocks() error {
@@ -107,5 +114,9 @@ func (s *Server) handleRequestRange(msg protocol.Message) error {
107114
}
108115

109116
func (s *Server) handleClientDone() error {
117+
// Restart protocol
118+
s.Protocol.Stop()
119+
s.initProtocol()
120+
s.Protocol.Start()
110121
return nil
111122
}

protocol/chainsync/server.go

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,40 +27,49 @@ type Server struct {
2727
*protocol.Protocol
2828
config *Config
2929
callbackContext CallbackContext
30+
protoOptions protocol.ProtocolOptions
31+
stateContext any
3032
}
3133

3234
// NewServer returns a new ChainSync server object
3335
func NewServer(stateContext interface{}, protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
34-
// Use node-to-client protocol ID
35-
ProtocolId := ProtocolIdNtC
36-
msgFromCborFunc := NewMsgFromCborNtC
37-
if protoOptions.Mode == protocol.ProtocolModeNodeToNode {
38-
// Use node-to-node protocol ID
39-
ProtocolId = ProtocolIdNtN
40-
msgFromCborFunc = NewMsgFromCborNtN
41-
}
4236
s := &Server{
4337
config: cfg,
38+
// Save these for re-use later
39+
protoOptions: protoOptions,
40+
stateContext: stateContext,
4441
}
4542
s.callbackContext = CallbackContext{
4643
Server: s,
4744
ConnectionId: protoOptions.ConnectionId,
4845
}
46+
s.initProtocol()
47+
return s
48+
}
49+
50+
func (s *Server) initProtocol() {
51+
// Use node-to-client protocol ID
52+
ProtocolId := ProtocolIdNtC
53+
msgFromCborFunc := NewMsgFromCborNtC
54+
if s.protoOptions.Mode == protocol.ProtocolModeNodeToNode {
55+
// Use node-to-node protocol ID
56+
ProtocolId = ProtocolIdNtN
57+
msgFromCborFunc = NewMsgFromCborNtN
58+
}
4959
protoConfig := protocol.ProtocolConfig{
5060
Name: ProtocolName,
5161
ProtocolId: ProtocolId,
52-
Muxer: protoOptions.Muxer,
53-
ErrorChan: protoOptions.ErrorChan,
54-
Mode: protoOptions.Mode,
62+
Muxer: s.protoOptions.Muxer,
63+
ErrorChan: s.protoOptions.ErrorChan,
64+
Mode: s.protoOptions.Mode,
5565
Role: protocol.ProtocolRoleServer,
5666
MessageHandlerFunc: s.messageHandler,
5767
MessageFromCborFunc: msgFromCborFunc,
5868
StateMap: StateMap,
59-
StateContext: stateContext,
69+
StateContext: s.stateContext,
6070
InitialState: stateIdle,
6171
}
6272
s.Protocol = protocol.New(protoConfig)
63-
return s
6473
}
6574

6675
func (s *Server) RollBackward(point common.Point, tip Tip) error {
@@ -147,5 +156,9 @@ func (s *Server) handleFindIntersect(msg protocol.Message) error {
147156
}
148157

149158
func (s *Server) handleDone() error {
159+
// Restart protocol
160+
s.Protocol.Stop()
161+
s.initProtocol()
162+
s.Protocol.Start()
150163
return nil
151164
}

protocol/peersharing/server.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,38 @@ type Server struct {
2525
*protocol.Protocol
2626
config *Config
2727
callbackContext CallbackContext
28+
protoOptions protocol.ProtocolOptions
2829
}
2930

3031
// NewServer returns a new PeerSharing server object
3132
func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
3233
s := &Server{
3334
config: cfg,
35+
// Save this for re-use later
36+
protoOptions: protoOptions,
3437
}
3538
s.callbackContext = CallbackContext{
3639
Server: s,
3740
ConnectionId: protoOptions.ConnectionId,
3841
}
42+
s.initProtocol()
43+
return s
44+
}
45+
46+
func (s *Server) initProtocol() {
3947
protoConfig := protocol.ProtocolConfig{
4048
Name: ProtocolName,
4149
ProtocolId: ProtocolId,
42-
Muxer: protoOptions.Muxer,
43-
ErrorChan: protoOptions.ErrorChan,
44-
Mode: protoOptions.Mode,
50+
Muxer: s.protoOptions.Muxer,
51+
ErrorChan: s.protoOptions.ErrorChan,
52+
Mode: s.protoOptions.Mode,
4553
Role: protocol.ProtocolRoleServer,
4654
MessageHandlerFunc: s.handleMessage,
4755
MessageFromCborFunc: NewMsgFromCbor,
4856
StateMap: StateMap,
4957
InitialState: stateIdle,
5058
}
5159
s.Protocol = protocol.New(protoConfig)
52-
return s
5360
}
5461

5562
func (s *Server) handleMessage(msg protocol.Message) error {
@@ -88,5 +95,9 @@ func (s *Server) handleShareRequest(msg protocol.Message) error {
8895
}
8996

9097
func (s *Server) handleDone(msg protocol.Message) error {
98+
// Restart protocol
99+
s.Protocol.Stop()
100+
s.initProtocol()
101+
s.Protocol.Start()
91102
return nil
92103
}

protocol/protocol.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ type Protocol struct {
4444
sendReadyChan chan bool
4545
stateTransitionChan chan<- protocolStateTransition
4646
onceStart sync.Once
47+
onceStop sync.Once
4748
}
4849

4950
// ProtocolConfig provides the configuration for Protocol
@@ -147,6 +148,21 @@ func (p *Protocol) Start() {
147148
})
148149
}
149150

151+
// Stop shuts down the mini-protocol
152+
func (p *Protocol) Stop() {
153+
p.onceStop.Do(func() {
154+
// Unregister protocol from muxer
155+
muxerProtocolRole := muxer.ProtocolRoleInitiator
156+
if p.config.Role == ProtocolRoleServer {
157+
muxerProtocolRole = muxer.ProtocolRoleResponder
158+
}
159+
p.config.Muxer.RegisterProtocol(
160+
p.config.ProtocolId,
161+
muxerProtocolRole,
162+
)
163+
})
164+
}
165+
150166
// Mode returns the protocol mode
151167
func (p *Protocol) Mode() ProtocolMode {
152168
return p.config.Mode

protocol/txsubmission/server.go

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ type Server struct {
2626
*protocol.Protocol
2727
config *Config
2828
callbackContext CallbackContext
29+
protoOptions protocol.ProtocolOptions
2930
ackCount int
30-
stateDone bool
3131
requestTxIdsResultChan chan []TxIdAndSize
3232
requestTxsResultChan chan []TxBody
3333
onceStart sync.Once
@@ -36,28 +36,34 @@ type Server struct {
3636
// NewServer returns a new TxSubmission server object
3737
func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
3838
s := &Server{
39-
config: cfg,
39+
config: cfg,
40+
// Save this for re-use later
41+
protoOptions: protoOptions,
4042
requestTxIdsResultChan: make(chan []TxIdAndSize),
4143
requestTxsResultChan: make(chan []TxBody),
4244
}
4345
s.callbackContext = CallbackContext{
4446
Server: s,
4547
ConnectionId: protoOptions.ConnectionId,
4648
}
49+
s.initProtocol()
50+
return s
51+
}
52+
53+
func (s *Server) initProtocol() {
4754
protoConfig := protocol.ProtocolConfig{
4855
Name: ProtocolName,
4956
ProtocolId: ProtocolId,
50-
Muxer: protoOptions.Muxer,
51-
ErrorChan: protoOptions.ErrorChan,
52-
Mode: protoOptions.Mode,
57+
Muxer: s.protoOptions.Muxer,
58+
ErrorChan: s.protoOptions.ErrorChan,
59+
Mode: s.protoOptions.Mode,
5360
Role: protocol.ProtocolRoleServer,
5461
MessageHandlerFunc: s.messageHandler,
5562
MessageFromCborFunc: NewMsgFromCbor,
5663
StateMap: StateMap,
5764
InitialState: stateInit,
5865
}
5966
s.Protocol = protocol.New(protoConfig)
60-
return s
6167
}
6268

6369
func (s *Server) Start() {
@@ -98,9 +104,6 @@ func (s *Server) RequestTxIds(
98104
blocking bool,
99105
reqCount int,
100106
) ([]TxIdAndSize, error) {
101-
if s.stateDone {
102-
return nil, protocol.ProtocolShuttingDownError
103-
}
104107
msg := NewMsgRequestTxIds(blocking, uint16(s.ackCount), uint16(reqCount))
105108
if err := s.SendMessage(msg); err != nil {
106109
return nil, err
@@ -117,9 +120,6 @@ func (s *Server) RequestTxIds(
117120

118121
// RequestTxs requests the content of the requested TX identifiers from the remote node's mempool
119122
func (s *Server) RequestTxs(txIds []TxId) ([]TxBody, error) {
120-
if s.stateDone {
121-
return nil, protocol.ProtocolShuttingDownError
122-
}
123123
msg := NewMsgRequestTxs(txIds)
124124
if err := s.SendMessage(msg); err != nil {
125125
return nil, err
@@ -147,7 +147,12 @@ func (s *Server) handleReplyTxs(msg protocol.Message) error {
147147
}
148148

149149
func (s *Server) handleDone() error {
150-
s.stateDone = true
150+
// Restart protocol
151+
s.Protocol.Stop()
152+
s.initProtocol()
153+
s.requestTxIdsResultChan = make(chan []TxIdAndSize)
154+
s.requestTxsResultChan = make(chan []TxBody)
155+
s.Protocol.Start()
151156
return nil
152157
}
153158

0 commit comments

Comments
 (0)