Skip to content

Commit 6ab8c69

Browse files
authored
Merge pull request #78 from cloudstruct/feature/close-function
feat: method to shutdown connection
2 parents 6dbfa11 + cee39d6 commit 6ab8c69

File tree

3 files changed

+122
-18
lines changed

3 files changed

+122
-18
lines changed

muxer/muxer.go

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ type Muxer struct {
2121
conn net.Conn
2222
sendMutex sync.Mutex
2323
startChan chan bool
24+
doneChan chan bool
2425
ErrorChan chan error
2526
protocolSenders map[uint16]chan *Segment
2627
protocolReceivers map[uint16]chan *Segment
@@ -30,6 +31,7 @@ func New(conn net.Conn) *Muxer {
3031
m := &Muxer{
3132
conn: conn,
3233
startChan: make(chan bool, 1),
34+
doneChan: make(chan bool),
3335
ErrorChan: make(chan error, 10),
3436
protocolSenders: make(map[uint16]chan *Segment),
3537
protocolReceivers: make(map[uint16]chan *Segment),
@@ -42,6 +44,37 @@ func (m *Muxer) Start() {
4244
m.startChan <- true
4345
}
4446

47+
func (m *Muxer) Stop() {
48+
// Immediately return if we're already shutting down
49+
select {
50+
case <-m.doneChan:
51+
return
52+
default:
53+
}
54+
// Close protocol receive channels
55+
// We rely on the individual mini-protocols to close the sender channel
56+
for _, recvChan := range m.protocolReceivers {
57+
close(recvChan)
58+
}
59+
// Close ErrorChan to signify to consumer that we're shutting down
60+
close(m.ErrorChan)
61+
// Close doneChan to signify that we're shutting down
62+
close(m.doneChan)
63+
}
64+
65+
func (m *Muxer) sendError(err error) {
66+
// Immediately return if we're already shutting down
67+
select {
68+
case <-m.doneChan:
69+
return
70+
default:
71+
}
72+
// Send error to consumer
73+
m.ErrorChan <- err
74+
// Stop the muxer on any error
75+
m.Stop()
76+
}
77+
4578
func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment) {
4679
// Generate channels
4780
senderChan := make(chan *Segment, 10)
@@ -52,9 +85,17 @@ func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segmen
5285
// Start Goroutine to handle outbound messages
5386
go func() {
5487
for {
55-
msg := <-senderChan
56-
if err := m.Send(msg); err != nil {
57-
m.ErrorChan <- err
88+
select {
89+
case _, ok := <-m.doneChan:
90+
// doneChan has been closed, which means we're shutting down
91+
if !ok {
92+
return
93+
}
94+
case msg := <-senderChan:
95+
if err := m.Send(msg); err != nil {
96+
m.sendError(err)
97+
return
98+
}
5899
}
59100
}
60101
}()
@@ -81,9 +122,16 @@ func (m *Muxer) Send(msg *Segment) error {
81122
func (m *Muxer) readLoop() {
82123
started := false
83124
for {
125+
// Break out of read loop if we're shutting down
126+
select {
127+
case <-m.doneChan:
128+
return
129+
default:
130+
}
84131
header := SegmentHeader{}
85132
if err := binary.Read(m.conn, binary.BigEndian, &header); err != nil {
86-
m.ErrorChan <- err
133+
m.sendError(err)
134+
return
87135
}
88136
msg := &Segment{
89137
SegmentHeader: header,
@@ -92,24 +140,32 @@ func (m *Muxer) readLoop() {
92140
// We use ReadFull because it guarantees to read the expected number of bytes or
93141
// return an error
94142
if _, err := io.ReadFull(m.conn, msg.Payload); err != nil {
95-
m.ErrorChan <- err
96-
}
97-
// Wait until the muxer is started to process anything other than handshake messages
98-
if !started && msg.GetProtocolId() != PROTOCOL_HANDSHAKE {
99-
<-m.startChan
100-
started = true
143+
m.sendError(err)
144+
return
101145
}
102146
// Send message payload to proper receiver
103147
recvChan := m.protocolReceivers[msg.GetProtocolId()]
104148
if recvChan == nil {
105149
// Try the "unknown protocol" receiver if we didn't find an explicit one
106150
recvChan = m.protocolReceivers[PROTOCOL_UNKNOWN]
107151
if recvChan == nil {
108-
m.ErrorChan <- fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId())
152+
m.sendError(fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId()))
153+
return
109154
}
110155
}
111156
if recvChan != nil {
112157
recvChan <- msg
113158
}
159+
// Wait until the muxer is started to continue
160+
// We don't want to read more than one segment until the handshake is complete
161+
if !started {
162+
select {
163+
case <-m.doneChan:
164+
// Break out of read loop if we're shutting down
165+
return
166+
case <-m.startChan:
167+
started = true
168+
}
169+
}
114170
}
115171
}

ouroboros.go

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ouroboros
22

33
import (
4+
"fmt"
45
"github.com/cloudstruct/go-ouroboros-network/muxer"
56
"github.com/cloudstruct/go-ouroboros-network/protocol"
67
"github.com/cloudstruct/go-ouroboros-network/protocol/blockfetch"
@@ -21,6 +22,7 @@ type Ouroboros struct {
2122
handshakeComplete bool
2223
muxer *muxer.Muxer
2324
ErrorChan chan error
25+
protoErrorChan chan error
2426
sendKeepAlives bool
2527
delayMuxerStart bool
2628
// Mini-protocols
@@ -68,6 +70,7 @@ func New(options *OuroborosOptions) (*Ouroboros, error) {
6870
ErrorChan: options.ErrorChan,
6971
sendKeepAlives: options.SendKeepAlives,
7072
delayMuxerStart: options.DelayMuxerStart,
73+
protoErrorChan: make(chan error, 10),
7174
}
7275
if o.ErrorChan == nil {
7376
o.ErrorChan = make(chan error, 10)
@@ -98,16 +101,32 @@ func (o *Ouroboros) Dial(proto string, address string) error {
98101
return nil
99102
}
100103

104+
func (o *Ouroboros) Close() error {
105+
// Gracefully stop the muxer
106+
o.muxer.Stop()
107+
// Close the underlying connection
108+
if err := o.conn.Close(); err != nil {
109+
return err
110+
}
111+
return nil
112+
}
113+
101114
func (o *Ouroboros) setupConnection() error {
102115
o.muxer = muxer.New(o.conn)
103116
// Start Goroutine to pass along errors from the muxer
104117
go func() {
105-
err := <-o.muxer.ErrorChan
106-
o.ErrorChan <- err
118+
err, ok := <-o.muxer.ErrorChan
119+
// Break out of goroutine if muxer's error channel is closed
120+
if !ok {
121+
return
122+
}
123+
o.ErrorChan <- fmt.Errorf("muxer error: %s", err)
124+
// Close connection on muxer errors
125+
o.Close()
107126
}()
108127
protoOptions := protocol.ProtocolOptions{
109128
Muxer: o.muxer,
110-
ErrorChan: o.ErrorChan,
129+
ErrorChan: o.protoErrorChan,
111130
}
112131
var protoVersions []uint16
113132
if o.useNodeToNodeProto {
@@ -131,13 +150,26 @@ func (o *Ouroboros) setupConnection() error {
131150
return err
132151
}
133152
}
134-
o.handshakeComplete = <-o.Handshake.Finished
153+
// Wait for handshake completion or error
154+
select {
155+
case err := <-o.protoErrorChan:
156+
return err
157+
case finished := <-o.Handshake.Finished:
158+
o.handshakeComplete = finished
159+
}
135160
// Provide the negotiated protocol version to the various mini-protocols
136161
protoOptions.Version = o.Handshake.Version
137162
// Drop bit used to signify NtC protocol versions
138163
if protoOptions.Version > PROTOCOL_VERSION_NTC_FLAG {
139164
protoOptions.Version = protoOptions.Version - PROTOCOL_VERSION_NTC_FLAG
140165
}
166+
// Start Goroutine to pass along errors from the mini-protocols
167+
go func() {
168+
err := <-o.protoErrorChan
169+
o.ErrorChan <- fmt.Errorf("protocol error: %s", err)
170+
// Close connection on mini-protocol errors
171+
o.Close()
172+
}()
141173
// Configure the relevant mini-protocols
142174
if o.useNodeToNodeProto {
143175
versionNtN := GetProtocolVersionNtN(o.Handshake.Version)
@@ -160,6 +192,7 @@ func (o *Ouroboros) setupConnection() error {
160192
o.LocalStateQuery = localstatequery.New(protoOptions, o.localStateQueryCallbackConfig)
161193
}
162194
}
195+
// Start muxer
163196
if !o.delayMuxerStart {
164197
o.muxer.Start()
165198
}

protocol/protocol.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type Protocol struct {
2525
sendStateQueueChan chan Message
2626
recvReadyChan chan bool
2727
sendReadyChan chan bool
28+
doneChan chan bool
2829
}
2930

3031
type ProtocolConfig struct {
@@ -79,6 +80,7 @@ func New(config ProtocolConfig) *Protocol {
7980
sendStateQueueChan: make(chan Message, 50),
8081
recvReadyChan: make(chan bool, 1),
8182
sendReadyChan: make(chan bool, 1),
83+
doneChan: make(chan bool),
8284
}
8385
// Set initial state
8486
p.setState(config.InitialState)
@@ -110,8 +112,16 @@ func (p *Protocol) sendLoop() {
110112
var newState State
111113
var err error
112114
for {
113-
// Wait until ready to send based on state map
114-
<-p.sendReadyChan
115+
select {
116+
case <-p.sendReadyChan:
117+
// We are ready to send based on state map
118+
case <-p.doneChan:
119+
// We are responsible for closing this channel as the sender, even through it
120+
// was created by the muxer
121+
close(p.muxerSendChan)
122+
// Break out of send loop if we're shutting down
123+
return
124+
}
115125
// Lock the state to prevent collisions
116126
p.stateMutex.Lock()
117127
// Check for queued state changes from previous pipelined sends
@@ -212,7 +222,12 @@ func (p *Protocol) recvLoop() {
212222
// Don't grab the next segment from the muxer if we still have data in the buffer
213223
if !leftoverData {
214224
// Wait for segment
215-
segment := <-p.muxerRecvChan
225+
segment, ok := <-p.muxerRecvChan
226+
// Break out of receive loop if channel is closed
227+
if !ok {
228+
close(p.doneChan)
229+
return
230+
}
216231
// Add segment payload to buffer
217232
p.recvBuffer.Write(segment.Payload)
218233
// Save whether it's a response

0 commit comments

Comments
 (0)