Skip to content

Commit be43b70

Browse files
committed
merge @SpencerPark commit 710233f
From SpencerPark <[email protected]>: Lock writes to the sockets to fix a bug with interleaved publishes.
1 parent 633fa16 commit be43b70

File tree

2 files changed

+66
-38
lines changed

2 files changed

+66
-38
lines changed

kernel.go

Lines changed: 60 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,20 @@ type ConnectionInfo struct {
3636
IP string `json:"ip"`
3737
}
3838

39+
// Socket wraps a zmq socket with a lock which should be used to control write access.
40+
type Socket struct {
41+
Socket *zmq.Socket
42+
Lock *sync.Mutex
43+
}
44+
3945
// SocketGroup holds the sockets needed to communicate with the kernel,
4046
// and the key for message signing.
4147
type SocketGroup struct {
42-
ShellSocket *zmq.Socket
43-
ControlSocket *zmq.Socket
44-
StdinSocket *zmq.Socket
45-
IOPubSocket *zmq.Socket
46-
HBSocket *zmq.Socket
48+
ShellSocket Socket
49+
ControlSocket Socket
50+
StdinSocket Socket
51+
IOPubSocket Socket
52+
HBSocket Socket
4753
Key []byte
4854
}
4955

@@ -85,6 +91,13 @@ const (
8591
kernelIdle = "idle"
8692
)
8793

94+
// RunWithSocket invokes the `run` function after acquiring the `Socket.Lock` and releases the lock when done.
95+
func (s *Socket) RunWithSocket(run func(socket *zmq.Socket) error) error {
96+
s.Lock.Lock()
97+
defer s.Lock.Unlock()
98+
return run(s.Socket)
99+
}
100+
88101
// runKernel is the main entry point to start the kernel.
89102
func runKernel(connectionFile string) {
90103

@@ -127,9 +140,9 @@ func runKernel(connectionFile string) {
127140
// TODO gracefully shutdown the heartbeat handler on kernel shutdown by closing the chan returned by startHeartbeat.
128141

129142
poller := zmq.NewPoller()
130-
poller.Add(sockets.ShellSocket, zmq.POLLIN)
131-
poller.Add(sockets.StdinSocket, zmq.POLLIN)
132-
poller.Add(sockets.ControlSocket, zmq.POLLIN)
143+
poller.Add(sockets.ShellSocket.Socket, zmq.POLLIN)
144+
poller.Add(sockets.StdinSocket.Socket, zmq.POLLIN)
145+
poller.Add(sockets.ControlSocket.Socket, zmq.POLLIN)
133146

134147
// msgParts will store a received multipart message.
135148
var msgParts [][]byte
@@ -147,8 +160,8 @@ func runKernel(connectionFile string) {
147160
switch socket := item.Socket; socket {
148161

149162
// Handle shell messages.
150-
case sockets.ShellSocket:
151-
msgParts, err = sockets.ShellSocket.RecvMessageBytes(0)
163+
case sockets.ShellSocket.Socket:
164+
msgParts, err = sockets.ShellSocket.Socket.RecvMessageBytes(0)
152165
if err != nil {
153166
log.Println(err)
154167
}
@@ -162,12 +175,12 @@ func runKernel(connectionFile string) {
162175
handleShellMsg(ir, msgReceipt{msg, ids, sockets})
163176

164177
// TODO Handle stdin socket.
165-
case sockets.StdinSocket:
166-
sockets.StdinSocket.RecvMessageBytes(0)
178+
case sockets.StdinSocket.Socket:
179+
sockets.StdinSocket.Socket.RecvMessageBytes(0)
167180

168181
// Handle control messages.
169-
case sockets.ControlSocket:
170-
msgParts, err = sockets.ControlSocket.RecvMessageBytes(0)
182+
case sockets.ControlSocket.Socket:
183+
msgParts, err = sockets.ControlSocket.Socket.RecvMessageBytes(0)
171184
if err != nil {
172185
log.Println(err)
173186
return
@@ -200,46 +213,51 @@ func prepareSockets(connInfo ConnectionInfo) (SocketGroup, error) {
200213

201214
// Create the shell socket, a request-reply socket that may receive messages from multiple frontend for
202215
// code execution, introspection, auto-completion, etc.
203-
sg.ShellSocket, err = context.NewSocket(zmq.ROUTER)
216+
sg.ShellSocket.Socket, err = context.NewSocket(zmq.ROUTER)
217+
sg.ShellSocket.Lock = &sync.Mutex{}
204218
if err != nil {
205219
return sg, err
206220
}
207221

208222
// Create the control socket. This socket is a duplicate of the shell socket where messages on this channel
209223
// should jump ahead of queued messages on the shell socket.
210-
sg.ControlSocket, err = context.NewSocket(zmq.ROUTER)
224+
sg.ControlSocket.Socket, err = context.NewSocket(zmq.ROUTER)
225+
sg.ControlSocket.Lock = &sync.Mutex{}
211226
if err != nil {
212227
return sg, err
213228
}
214229

215230
// Create the stdin socket, a request-reply socket used to request user input from a front-end. This is analogous
216231
// to a standard input stream.
217-
sg.StdinSocket, err = context.NewSocket(zmq.ROUTER)
232+
sg.StdinSocket.Socket, err = context.NewSocket(zmq.ROUTER)
233+
sg.StdinSocket.Lock = &sync.Mutex{}
218234
if err != nil {
219235
return sg, err
220236
}
221237

222238
// Create the iopub socket, a publisher for broadcasting data like stdout/stderr output, displaying execution
223239
// results or errors, kernel status, etc. to connected subscribers.
224-
sg.IOPubSocket, err = context.NewSocket(zmq.PUB)
240+
sg.IOPubSocket.Socket, err = context.NewSocket(zmq.PUB)
241+
sg.IOPubSocket.Lock = &sync.Mutex{}
225242
if err != nil {
226243
return sg, err
227244
}
228245

229246
// Create the heartbeat socket, a request-reply socket that only allows alternating recv-send (request-reply)
230247
// calls. It should echo the byte strings it receives to let the requester know the kernel is still alive.
231-
sg.HBSocket, err = context.NewSocket(zmq.REP)
248+
sg.HBSocket.Socket, err = context.NewSocket(zmq.REP)
249+
sg.HBSocket.Lock = &sync.Mutex{}
232250
if err != nil {
233251
return sg, err
234252
}
235253

236254
// Bind the sockets.
237255
address := fmt.Sprintf("%v://%v:%%v", connInfo.Transport, connInfo.IP)
238-
sg.ShellSocket.Bind(fmt.Sprintf(address, connInfo.ShellPort))
239-
sg.ControlSocket.Bind(fmt.Sprintf(address, connInfo.ControlPort))
240-
sg.StdinSocket.Bind(fmt.Sprintf(address, connInfo.StdinPort))
241-
sg.IOPubSocket.Bind(fmt.Sprintf(address, connInfo.IOPubPort))
242-
sg.HBSocket.Bind(fmt.Sprintf(address, connInfo.HBPort))
256+
sg.ShellSocket.Socket.Bind(fmt.Sprintf(address, connInfo.ShellPort))
257+
sg.ControlSocket.Socket.Bind(fmt.Sprintf(address, connInfo.ControlPort))
258+
sg.StdinSocket.Socket.Bind(fmt.Sprintf(address, connInfo.StdinPort))
259+
sg.IOPubSocket.Socket.Bind(fmt.Sprintf(address, connInfo.IOPubPort))
260+
sg.HBSocket.Socket.Bind(fmt.Sprintf(address, connInfo.HBPort))
243261

244262
// Set the message signing key.
245263
sg.Key = []byte(connInfo.Key)
@@ -496,7 +514,7 @@ func handleShutdownRequest(receipt msgReceipt) {
496514
// startHeartbeat starts a go-routine for handling heartbeat ping messages sent over the given `hbSocket`. The `wg`'s
497515
// `Done` method is invoked after the thread is completely shutdown. To request a shutdown the returned `shutdown` channel
498516
// can be closed.
499-
func startHeartbeat(hbSocket *zmq.Socket, wg *sync.WaitGroup) (shutdown chan struct{}) {
517+
func startHeartbeat(hbSocket Socket, wg *sync.WaitGroup) (shutdown chan struct{}) {
500518
quit := make(chan struct{})
501519

502520
// Start the handler that will echo any received messages back to the sender.
@@ -506,7 +524,7 @@ func startHeartbeat(hbSocket *zmq.Socket, wg *sync.WaitGroup) (shutdown chan str
506524

507525
// Create a `Poller` to check for incoming messages.
508526
poller := zmq.NewPoller()
509-
poller.Add(hbSocket, zmq.POLLIN)
527+
poller.Add(hbSocket.Socket, zmq.POLLIN)
510528

511529
for {
512530
select {
@@ -521,16 +539,22 @@ func startHeartbeat(hbSocket *zmq.Socket, wg *sync.WaitGroup) (shutdown chan str
521539

522540
// If there is at least 1 message waiting then echo it.
523541
if len(pingEvents) > 0 {
524-
// Read a message from the heartbeat channel as a simple byte string.
525-
pingMsg, err := hbSocket.RecvBytes(0)
526-
if err != nil {
527-
log.Fatalf("Error reading heartbeat ping bytes: %v\n", err)
528-
}
529-
530-
// Send the received byte string back to let the front-end know that the kernel is alive.
531-
if _, err = hbSocket.SendBytes(pingMsg, 0); err != nil {
532-
log.Printf("Error sending heartbeat pong bytes: %b\n", err)
533-
}
542+
hbSocket.RunWithSocket(func(echo *zmq.Socket) error {
543+
// Read a message from the heartbeat channel as a simple byte string.
544+
pingMsg, err := echo.RecvBytes(0)
545+
if err != nil {
546+
log.Fatalf("Error reading heartbeat ping bytes: %v\n", err)
547+
return err
548+
}
549+
550+
// Send the received byte string back to let the front-end know that the kernel is alive.
551+
if _, err = echo.SendBytes(pingMsg, 0); err != nil {
552+
log.Printf("Error sending heartbeat pong bytes: %b\n", err)
553+
return err
554+
}
555+
556+
return nil
557+
})
534558
}
535559
}
536560
}

messages.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,9 @@ func (receipt *msgReceipt) Publish(msgType string, content interface{}) error {
189189
}
190190

191191
msg.Content = content
192-
return receipt.SendResponse(receipt.Sockets.IOPubSocket, msg)
192+
return receipt.Sockets.IOPubSocket.RunWithSocket(func(iopub *zmq.Socket) error {
193+
return receipt.SendResponse(iopub, msg)
194+
})
193195
}
194196

195197
// Reply creates a new ComposedMsg and sends it back to the return identities over the
@@ -202,7 +204,9 @@ func (receipt *msgReceipt) Reply(msgType string, content interface{}) error {
202204
}
203205

204206
msg.Content = content
205-
return receipt.SendResponse(receipt.Sockets.ShellSocket, msg)
207+
return receipt.Sockets.ShellSocket.RunWithSocket(func(shell *zmq.Socket) error {
208+
return receipt.SendResponse(shell, msg)
209+
})
206210
}
207211

208212
// newTextMIMEDataBundle creates a bundledMIMEData that only contains a text representation described

0 commit comments

Comments
 (0)