Skip to content

Commit a98d6dd

Browse files
authored
Merge pull request #183 from cloudstruct/feat/protocol-resource-cleanup
feat: cleanup protocol resources on shutdown
2 parents 1ec6209 + 4432835 commit a98d6dd

File tree

10 files changed

+213
-39
lines changed

10 files changed

+213
-39
lines changed

Makefile

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,28 @@
1-
BINARY=go-ouroboros-network
2-
31
# Determine root directory
42
ROOT_DIR=$(shell dirname $(realpath $(firstword $(MAKEFILE_LIST))))
53

64
# Gather all .go files for use in dependencies below
75
GO_FILES=$(shell find $(ROOT_DIR) -name '*.go')
86

9-
# Build our program binary
10-
# Depends on GO_FILES to determine when rebuild is needed
11-
$(BINARY): $(GO_FILES)
12-
# Needed to fetch new dependencies and add them to go.mod
13-
go mod tidy
14-
go build -o $(BINARY) ./cmd/$(BINARY)
7+
# Gather list of expected binaries
8+
BINARIES=$(shell cd $(ROOT_DIR)/cmd && ls -1)
159

16-
.PHONY: build clean test
10+
.PHONY: build mod-tidy clean test
1711

1812
# Alias for building program binary
19-
build: $(BINARY)
13+
build: $(BINARIES)
14+
15+
mod-tidy:
16+
# Needed to fetch new dependencies and add them to go.mod
17+
go mod tidy
2018

2119
clean:
22-
rm -f $(BINARY)
20+
rm -f $(BINARIES)
2321

2422
test:
2523
go test -v ./...
24+
25+
# Build our program binaries
26+
# Depends on GO_FILES to determine when rebuild is needed
27+
$(BINARIES): mod-tidy $(GO_FILES)
28+
go build -o $(@) ./cmd/$(@)

cmd/go-ouroboros-network/main.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ func main() {
7878
testServer(f)
7979
case "query":
8080
testQuery(f)
81+
case "mem-usage":
82+
testMemUsage(f)
8183
default:
8284
fmt.Printf("Unknown subcommand: %s\n", f.flagset.Arg(0))
8385
os.Exit(1)
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package main
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
"log"
7+
"net/http"
8+
_ "net/http/pprof"
9+
"os"
10+
"runtime"
11+
"runtime/pprof"
12+
"time"
13+
14+
ouroboros "github.com/cloudstruct/go-ouroboros-network"
15+
)
16+
17+
type memUsageFlags struct {
18+
flagset *flag.FlagSet
19+
startEra string
20+
tip bool
21+
debugPort int
22+
}
23+
24+
func newMemUsageFlags() *memUsageFlags {
25+
f := &memUsageFlags{
26+
flagset: flag.NewFlagSet("mem-usage", flag.ExitOnError),
27+
}
28+
f.flagset.StringVar(&f.startEra, "start-era", "genesis", "era which to start chain-sync at")
29+
f.flagset.BoolVar(&f.tip, "tip", false, "start chain-sync at current chain tip")
30+
f.flagset.IntVar(&f.debugPort, "debug-port", 8080, "pprof port")
31+
return f
32+
}
33+
34+
func testMemUsage(f *globalFlags) {
35+
memUsageFlags := newMemUsageFlags()
36+
err := memUsageFlags.flagset.Parse(f.flagset.Args()[1:])
37+
if err != nil {
38+
fmt.Printf("failed to parse subcommand args: %s\n", err)
39+
os.Exit(1)
40+
}
41+
42+
// Start pprof listener
43+
log.Printf("Starting pprof listener on http://0.0.0.0:%d/debug/pprof\n", memUsageFlags.debugPort)
44+
go func() {
45+
log.Println(http.ListenAndServe(fmt.Sprintf(":%d", memUsageFlags.debugPort), nil))
46+
}()
47+
48+
for i := 0; i < 10; i++ {
49+
showMemoryStats("open")
50+
51+
conn := createClientConnection(f)
52+
errorChan := make(chan error)
53+
go func() {
54+
for {
55+
err, ok := <-errorChan
56+
if !ok {
57+
return
58+
}
59+
fmt.Printf("ERROR: %s\n", err)
60+
os.Exit(1)
61+
}
62+
}()
63+
o, err := ouroboros.New(
64+
ouroboros.WithConnection(conn),
65+
ouroboros.WithNetworkMagic(uint32(f.networkMagic)),
66+
ouroboros.WithErrorChan(errorChan),
67+
ouroboros.WithNodeToNode(f.ntnProto),
68+
ouroboros.WithKeepAlive(true),
69+
)
70+
if err != nil {
71+
fmt.Printf("ERROR: %s\n", err)
72+
os.Exit(1)
73+
}
74+
o.ChainSync.Client.Start()
75+
76+
tip, err := o.ChainSync.Client.GetCurrentTip()
77+
if err != nil {
78+
fmt.Printf("ERROR: %s\n", err)
79+
os.Exit(1)
80+
}
81+
82+
log.Printf("tip: slot = %d, hash = %x\n", tip.Point.Slot, tip.Point.Hash)
83+
84+
if err := o.Close(); err != nil {
85+
fmt.Printf("ERROR: %s\n", err)
86+
}
87+
88+
showMemoryStats("close")
89+
90+
time.Sleep(5 * time.Second)
91+
92+
runtime.GC()
93+
94+
showMemoryStats("after GC")
95+
}
96+
97+
if err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1); err != nil {
98+
fmt.Printf("ERROR: %s\n", err)
99+
os.Exit(1)
100+
}
101+
102+
fmt.Printf("waiting forever")
103+
select {}
104+
}
105+
106+
func showMemoryStats(tag string) {
107+
var m runtime.MemStats
108+
runtime.ReadMemStats(&m)
109+
log.Printf("[%s] HeapAlloc: %dKiB\n", tag, m.HeapAlloc/1024)
110+
}

muxer/muxer.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func (m *Muxer) sendError(err error) {
9494
m.Stop()
9595
}
9696

97-
func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment) {
97+
func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment, chan bool) {
9898
// Generate channels
9999
senderChan := make(chan *Segment, 10)
100100
receiverChan := make(chan *Segment, 10)
@@ -118,7 +118,7 @@ func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segmen
118118
}
119119
}
120120
}()
121-
return senderChan, receiverChan
121+
return senderChan, receiverChan, m.doneChan
122122
}
123123

124124
func (m *Muxer) Send(msg *Segment) error {

ouroboros.go

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -132,20 +132,24 @@ func (o *Ouroboros) setupConnection() error {
132132
o.muxer = muxer.New(o.conn)
133133
// Start Goroutine to pass along errors from the muxer
134134
go func() {
135-
err, ok := <-o.muxer.ErrorChan
136-
// Break out of goroutine if muxer's error channel is closed
137-
if !ok {
135+
select {
136+
case <-o.doneChan:
138137
return
138+
case err, ok := <-o.muxer.ErrorChan:
139+
// Break out of goroutine if muxer's error channel is closed
140+
if !ok {
141+
return
142+
}
143+
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
144+
// Return a bare io.EOF error if error is EOF/ErrUnexpectedEOF
145+
o.ErrorChan <- io.EOF
146+
} else {
147+
// Wrap error message to denote it comes from the muxer
148+
o.ErrorChan <- fmt.Errorf("muxer error: %s", err)
149+
}
150+
// Close connection on muxer errors
151+
o.Close()
139152
}
140-
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
141-
// Return a bare io.EOF error if error is EOF/ErrUnexpectedEOF
142-
o.ErrorChan <- io.EOF
143-
} else {
144-
// Wrap error message to denote it comes from the muxer
145-
o.ErrorChan <- fmt.Errorf("muxer error: %s", err)
146-
}
147-
// Close connection on muxer errors
148-
o.Close()
149153
}()
150154
protoOptions := protocol.ProtocolOptions{
151155
Muxer: o.muxer,

protocol/chainsync/client.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
6363
InitialState: STATE_IDLE,
6464
}
6565
c.Protocol = protocol.New(protoConfig)
66+
// Start goroutine to cleanup resources on protocol shutdown
67+
go func() {
68+
<-c.Protocol.DoneChan()
69+
close(c.intersectResultChan)
70+
close(c.readyForNextBlockChan)
71+
close(c.currentTipChan)
72+
}()
6673
return c
6774
}
6875

protocol/keepalive/client.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
4040
InitialState: STATE_CLIENT,
4141
}
4242
c.Protocol = protocol.New(protoConfig)
43+
// Start goroutine to cleanup resources on protocol shutdown
44+
go func() {
45+
<-c.Protocol.DoneChan()
46+
if c.timer != nil {
47+
// Stop timer and drain channel
48+
if ok := c.timer.Stop(); !ok {
49+
<-c.timer.C
50+
}
51+
}
52+
}()
4353
return c
4454
}
4555

protocol/localstatequery/client.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
6565
c.enableGetRewardInfoPoolsBlock = true
6666
}
6767
c.Protocol = protocol.New(protoConfig)
68+
// Start goroutine to cleanup resources on protocol shutdown
69+
go func() {
70+
<-c.Protocol.DoneChan()
71+
close(c.queryResultChan)
72+
close(c.acquireResultChan)
73+
}()
6874
return c
6975
}
7076

protocol/localtxsubmission/client.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
4343
InitialState: STATE_IDLE,
4444
}
4545
c.Protocol = protocol.New(protoConfig)
46+
// Start goroutine to cleanup resources on protocol shutdown
47+
go func() {
48+
<-c.Protocol.DoneChan()
49+
close(c.submitResultChan)
50+
}()
4651
return c
4752
}
4853

protocol/protocol.go

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ type Protocol struct {
2222
config ProtocolConfig
2323
muxerSendChan chan *muxer.Segment
2424
muxerRecvChan chan *muxer.Segment
25+
muxerDoneChan chan bool
2526
state State
2627
stateMutex sync.Mutex
2728
recvBuffer *bytes.Buffer
@@ -77,21 +78,29 @@ type MessageFromCborFunc func(uint, []byte) (Message, error)
7778

7879
func New(config ProtocolConfig) *Protocol {
7980
p := &Protocol{
80-
config: config,
81+
config: config,
82+
doneChan: make(chan bool),
8183
}
8284
return p
8385
}
8486

8587
func (p *Protocol) Start() {
8688
// Register protocol with muxer
87-
p.muxerSendChan, p.muxerRecvChan = p.config.Muxer.RegisterProtocol(p.config.ProtocolId)
89+
p.muxerSendChan, p.muxerRecvChan, p.muxerDoneChan = p.config.Muxer.RegisterProtocol(p.config.ProtocolId)
8890
// Create buffers and channels
8991
p.recvBuffer = bytes.NewBuffer(nil)
9092
p.sendQueueChan = make(chan Message, 50)
9193
p.sendStateQueueChan = make(chan Message, 50)
9294
p.recvReadyChan = make(chan bool, 1)
9395
p.sendReadyChan = make(chan bool, 1)
94-
p.doneChan = make(chan bool)
96+
// Start goroutine to cleanup when shutting down
97+
go func() {
98+
<-p.doneChan
99+
close(p.sendQueueChan)
100+
close(p.sendStateQueueChan)
101+
close(p.recvReadyChan)
102+
close(p.sendReadyChan)
103+
}()
95104
// Set initial state
96105
p.setState(p.config.InitialState)
97106
// Start our send and receive Goroutines
@@ -107,6 +116,10 @@ func (p *Protocol) Role() ProtocolRole {
107116
return p.config.Role
108117
}
109118

119+
func (p *Protocol) DoneChan() chan bool {
120+
return p.doneChan
121+
}
122+
110123
func (p *Protocol) SendMessage(msg Message) error {
111124
p.sendQueueChan <- msg
112125
return nil
@@ -122,14 +135,14 @@ func (p *Protocol) sendLoop() {
122135
var err error
123136
for {
124137
select {
125-
case <-p.sendReadyChan:
126-
// We are ready to send based on state map
127138
case <-p.doneChan:
128139
// We are responsible for closing this channel as the sender, even through it
129140
// was created by the muxer
130141
close(p.muxerSendChan)
131142
// Break out of send loop if we're shutting down
132143
return
144+
case <-p.sendReadyChan:
145+
// We are ready to send based on state map
133146
}
134147
// Lock the state to prevent collisions
135148
p.stateMutex.Lock()
@@ -155,7 +168,11 @@ func (p *Protocol) sendLoop() {
155168
msgCount := 0
156169
for {
157170
// Get next message from send queue
158-
msg := <-p.sendQueueChan
171+
msg, ok := <-p.sendQueueChan
172+
if !ok {
173+
// We're shutting down
174+
return
175+
}
159176
msgCount = msgCount + 1
160177
// Write the message into the send state queue if we already have a new state
161178
if setNewState {
@@ -234,20 +251,29 @@ func (p *Protocol) recvLoop() {
234251
// Don't grab the next segment from the muxer if we still have data in the buffer
235252
if !leftoverData {
236253
// Wait for segment
237-
segment, ok := <-p.muxerRecvChan
238-
// Break out of receive loop if channel is closed
239-
if !ok {
254+
select {
255+
case <-p.muxerDoneChan:
240256
close(p.doneChan)
241257
return
258+
case segment, ok := <-p.muxerRecvChan:
259+
if !ok {
260+
close(p.doneChan)
261+
return
262+
}
263+
// Add segment payload to buffer
264+
p.recvBuffer.Write(segment.Payload)
265+
// Save whether it's a response
266+
isResponse = segment.IsResponse()
242267
}
243-
// Add segment payload to buffer
244-
p.recvBuffer.Write(segment.Payload)
245-
// Save whether it's a response
246-
isResponse = segment.IsResponse()
247268
}
248269
leftoverData = false
249270
// Wait until ready to receive based on state map
250-
<-p.recvReadyChan
271+
select {
272+
case <-p.muxerDoneChan:
273+
close(p.doneChan)
274+
return
275+
case <-p.recvReadyChan:
276+
}
251277
// Decode message into generic list until we can determine what type of message it is.
252278
// This also lets us determine how many bytes the message is. We use RawMessage here to
253279
// avoid parsing things that we may not be able to parse
@@ -321,6 +347,7 @@ func (p *Protocol) getNewState(msg Message) (State, error) {
321347
func (p *Protocol) setState(state State) {
322348
// Disable any previous state transition timer
323349
if p.stateTransitionTimer != nil {
350+
// Stop timer and drain channel
324351
if !p.stateTransitionTimer.Stop() {
325352
<-p.stateTransitionTimer.C
326353
}

0 commit comments

Comments
 (0)