@@ -36,14 +36,20 @@ type ConnectionInfo struct {
3636 IP string `json:"ip"`
3737}
3838
39+ // Socket wraps a zmq socket with a lock which should be used to control write access.
40+ type Socket struct {
41+ Socket * zmq.Socket
42+ Lock * sync.Mutex
43+ }
44+
3945// SocketGroup holds the sockets needed to communicate with the kernel,
4046// and the key for message signing.
4147type SocketGroup struct {
42- ShellSocket * zmq. Socket
43- ControlSocket * zmq. Socket
44- StdinSocket * zmq. Socket
45- IOPubSocket * zmq. Socket
46- HBSocket * zmq. Socket
48+ ShellSocket Socket
49+ ControlSocket Socket
50+ StdinSocket Socket
51+ IOPubSocket Socket
52+ HBSocket Socket
4753 Key []byte
4854}
4955
@@ -85,6 +91,13 @@ const (
8591 kernelIdle = "idle"
8692)
8793
94+ // RunWithSocket invokes the `run` function after acquiring the `Socket.Lock` and releases the lock when done.
95+ func (s * Socket ) RunWithSocket (run func (socket * zmq.Socket ) error ) error {
96+ s .Lock .Lock ()
97+ defer s .Lock .Unlock ()
98+ return run (s .Socket )
99+ }
100+
88101// runKernel is the main entry point to start the kernel.
89102func runKernel (connectionFile string ) {
90103
@@ -127,9 +140,9 @@ func runKernel(connectionFile string) {
127140 // TODO gracefully shutdown the heartbeat handler on kernel shutdown by closing the chan returned by startHeartbeat.
128141
129142 poller := zmq .NewPoller ()
130- poller .Add (sockets .ShellSocket , zmq .POLLIN )
131- poller .Add (sockets .StdinSocket , zmq .POLLIN )
132- poller .Add (sockets .ControlSocket , zmq .POLLIN )
143+ poller .Add (sockets .ShellSocket . Socket , zmq .POLLIN )
144+ poller .Add (sockets .StdinSocket . Socket , zmq .POLLIN )
145+ poller .Add (sockets .ControlSocket . Socket , zmq .POLLIN )
133146
134147 // msgParts will store a received multipart message.
135148 var msgParts [][]byte
@@ -147,8 +160,8 @@ func runKernel(connectionFile string) {
147160 switch socket := item .Socket ; socket {
148161
149162 // Handle shell messages.
150- case sockets .ShellSocket :
151- msgParts , err = sockets .ShellSocket .RecvMessageBytes (0 )
163+ case sockets .ShellSocket . Socket :
164+ msgParts , err = sockets .ShellSocket .Socket . RecvMessageBytes (0 )
152165 if err != nil {
153166 log .Println (err )
154167 }
@@ -162,12 +175,12 @@ func runKernel(connectionFile string) {
162175 handleShellMsg (ir , msgReceipt {msg , ids , sockets })
163176
164177 // TODO Handle stdin socket.
165- case sockets .StdinSocket :
166- sockets .StdinSocket .RecvMessageBytes (0 )
178+ case sockets .StdinSocket . Socket :
179+ sockets .StdinSocket .Socket . RecvMessageBytes (0 )
167180
168181 // Handle control messages.
169- case sockets .ControlSocket :
170- msgParts , err = sockets .ControlSocket .RecvMessageBytes (0 )
182+ case sockets .ControlSocket . Socket :
183+ msgParts , err = sockets .ControlSocket .Socket . RecvMessageBytes (0 )
171184 if err != nil {
172185 log .Println (err )
173186 return
@@ -200,46 +213,51 @@ func prepareSockets(connInfo ConnectionInfo) (SocketGroup, error) {
200213
201214 // Create the shell socket, a request-reply socket that may receive messages from multiple frontend for
202215 // code execution, introspection, auto-completion, etc.
203- sg .ShellSocket , err = context .NewSocket (zmq .ROUTER )
216+ sg .ShellSocket .Socket , err = context .NewSocket (zmq .ROUTER )
217+ sg .ShellSocket .Lock = & sync.Mutex {}
204218 if err != nil {
205219 return sg , err
206220 }
207221
208222 // Create the control socket. This socket is a duplicate of the shell socket where messages on this channel
209223 // should jump ahead of queued messages on the shell socket.
210- sg .ControlSocket , err = context .NewSocket (zmq .ROUTER )
224+ sg .ControlSocket .Socket , err = context .NewSocket (zmq .ROUTER )
225+ sg .ControlSocket .Lock = & sync.Mutex {}
211226 if err != nil {
212227 return sg , err
213228 }
214229
215230 // Create the stdin socket, a request-reply socket used to request user input from a front-end. This is analogous
216231 // to a standard input stream.
217- sg .StdinSocket , err = context .NewSocket (zmq .ROUTER )
232+ sg .StdinSocket .Socket , err = context .NewSocket (zmq .ROUTER )
233+ sg .StdinSocket .Lock = & sync.Mutex {}
218234 if err != nil {
219235 return sg , err
220236 }
221237
222238 // Create the iopub socket, a publisher for broadcasting data like stdout/stderr output, displaying execution
223239 // results or errors, kernel status, etc. to connected subscribers.
224- sg .IOPubSocket , err = context .NewSocket (zmq .PUB )
240+ sg .IOPubSocket .Socket , err = context .NewSocket (zmq .PUB )
241+ sg .IOPubSocket .Lock = & sync.Mutex {}
225242 if err != nil {
226243 return sg , err
227244 }
228245
229246 // Create the heartbeat socket, a request-reply socket that only allows alternating recv-send (request-reply)
230247 // calls. It should echo the byte strings it receives to let the requester know the kernel is still alive.
231- sg .HBSocket , err = context .NewSocket (zmq .REP )
248+ sg .HBSocket .Socket , err = context .NewSocket (zmq .REP )
249+ sg .HBSocket .Lock = & sync.Mutex {}
232250 if err != nil {
233251 return sg , err
234252 }
235253
236254 // Bind the sockets.
237255 address := fmt .Sprintf ("%v://%v:%%v" , connInfo .Transport , connInfo .IP )
238- sg .ShellSocket .Bind (fmt .Sprintf (address , connInfo .ShellPort ))
239- sg .ControlSocket .Bind (fmt .Sprintf (address , connInfo .ControlPort ))
240- sg .StdinSocket .Bind (fmt .Sprintf (address , connInfo .StdinPort ))
241- sg .IOPubSocket .Bind (fmt .Sprintf (address , connInfo .IOPubPort ))
242- sg .HBSocket .Bind (fmt .Sprintf (address , connInfo .HBPort ))
256+ sg .ShellSocket .Socket . Bind (fmt .Sprintf (address , connInfo .ShellPort ))
257+ sg .ControlSocket .Socket . Bind (fmt .Sprintf (address , connInfo .ControlPort ))
258+ sg .StdinSocket .Socket . Bind (fmt .Sprintf (address , connInfo .StdinPort ))
259+ sg .IOPubSocket .Socket . Bind (fmt .Sprintf (address , connInfo .IOPubPort ))
260+ sg .HBSocket .Socket . Bind (fmt .Sprintf (address , connInfo .HBPort ))
243261
244262 // Set the message signing key.
245263 sg .Key = []byte (connInfo .Key )
@@ -496,7 +514,7 @@ func handleShutdownRequest(receipt msgReceipt) {
496514// startHeartbeat starts a go-routine for handling heartbeat ping messages sent over the given `hbSocket`. The `wg`'s
497515// `Done` method is invoked after the thread is completely shutdown. To request a shutdown the returned `shutdown` channel
498516// can be closed.
499- func startHeartbeat (hbSocket * zmq. Socket , wg * sync.WaitGroup ) (shutdown chan struct {}) {
517+ func startHeartbeat (hbSocket Socket , wg * sync.WaitGroup ) (shutdown chan struct {}) {
500518 quit := make (chan struct {})
501519
502520 // Start the handler that will echo any received messages back to the sender.
@@ -506,7 +524,7 @@ func startHeartbeat(hbSocket *zmq.Socket, wg *sync.WaitGroup) (shutdown chan str
506524
507525 // Create a `Poller` to check for incoming messages.
508526 poller := zmq .NewPoller ()
509- poller .Add (hbSocket , zmq .POLLIN )
527+ poller .Add (hbSocket . Socket , zmq .POLLIN )
510528
511529 for {
512530 select {
@@ -521,16 +539,22 @@ func startHeartbeat(hbSocket *zmq.Socket, wg *sync.WaitGroup) (shutdown chan str
521539
522540 // If there is at least 1 message waiting then echo it.
523541 if len (pingEvents ) > 0 {
524- // Read a message from the heartbeat channel as a simple byte string.
525- pingMsg , err := hbSocket .RecvBytes (0 )
526- if err != nil {
527- log .Fatalf ("Error reading heartbeat ping bytes: %v\n " , err )
528- }
529-
530- // Send the received byte string back to let the front-end know that the kernel is alive.
531- if _ , err = hbSocket .SendBytes (pingMsg , 0 ); err != nil {
532- log .Printf ("Error sending heartbeat pong bytes: %b\n " , err )
533- }
542+ hbSocket .RunWithSocket (func (echo * zmq.Socket ) error {
543+ // Read a message from the heartbeat channel as a simple byte string.
544+ pingMsg , err := echo .RecvBytes (0 )
545+ if err != nil {
546+ log .Fatalf ("Error reading heartbeat ping bytes: %v\n " , err )
547+ return err
548+ }
549+
550+ // Send the received byte string back to let the front-end know that the kernel is alive.
551+ if _ , err = echo .SendBytes (pingMsg , 0 ); err != nil {
552+ log .Printf ("Error sending heartbeat pong bytes: %b\n " , err )
553+ return err
554+ }
555+
556+ return nil
557+ })
534558 }
535559 }
536560 }
0 commit comments