11package main
22
33import (
4- "bufio"
5- "bytes"
64 "encoding/json"
5+ "errors"
76 "fmt"
7+ "go/ast"
88 "io"
99 "io/ioutil"
1010 "log"
1111 "os"
1212 "runtime"
13- "strings"
13+ "sync"
14+ "time"
1415
16+ "github.com/cosmos72/gomacro/ast2"
1517 "github.com/cosmos72/gomacro/base"
1618 "github.com/cosmos72/gomacro/classic"
1719 zmq "github.com/pebbe/zmq4"
@@ -41,6 +43,7 @@ type SocketGroup struct {
4143 ControlSocket * zmq.Socket
4244 StdinSocket * zmq.Socket
4345 IOPubSocket * zmq.Socket
46+ HBSocket * zmq.Socket
4447 Key []byte
4548}
4649
@@ -71,7 +74,7 @@ type kernelInfo struct {
7174 HelpLinks []helpLink `json:"help_links"`
7275}
7376
74- // shutdownReply encodes a boolean indication of stutdown /restart.
77+ // shutdownReply encodes a boolean indication of shutdown /restart.
7578type shutdownReply struct {
7679 Restart bool `json:"restart"`
7780}
@@ -88,6 +91,10 @@ func runKernel(connectionFile string) {
8891 // Set up the "Session" with the replpkg.
8992 ir := classic .New ()
9093
94+ // Throw out the error/warning messages that gomacro outputs writes to these streams.
95+ ir .Stdout = ioutil .Discard
96+ ir .Stderr = ioutil .Discard
97+
9198 // Parse the connection info.
9299 var connInfo ConnectionInfo
93100
@@ -106,6 +113,13 @@ func runKernel(connectionFile string) {
106113 log .Fatal (err )
107114 }
108115
116+ // TODO connect all channel handlers to a WaitGroup to ensure shutdown before returning from runKernel.
117+
118+ // Start up the heartbeat handler.
119+ startHeartbeat (sockets .HBSocket , & sync.WaitGroup {})
120+
121+ // TODO gracefully shutdown the heartbeat handler on kernel shutdown by closing the chan returned by startHeartbeat.
122+
109123 poller := zmq .NewPoller ()
110124 poller .Add (sockets .ShellSocket , zmq .POLLIN )
111125 poller .Add (sockets .StdinSocket , zmq .POLLIN )
@@ -116,7 +130,6 @@ func runKernel(connectionFile string) {
116130
117131 // Start a message receiving loop.
118132 for {
119-
120133 polled , err := poller .Poll (- 1 )
121134 if err != nil {
122135 log .Fatal (err )
@@ -179,32 +192,48 @@ func prepareSockets(connInfo ConnectionInfo) (SocketGroup, error) {
179192 // Initialize the socket group.
180193 var sg SocketGroup
181194
195+ // Create the shell socket, a request-reply socket that may receive messages from multiple frontend for
196+ // code execution, introspection, auto-completion, etc.
182197 sg .ShellSocket , err = context .NewSocket (zmq .ROUTER )
183198 if err != nil {
184199 return sg , err
185200 }
186201
202+ // Create the control socket. This socket is a duplicate of the shell socket where messages on this channel
203+ // should jump ahead of queued messages on the shell socket.
187204 sg .ControlSocket , err = context .NewSocket (zmq .ROUTER )
188205 if err != nil {
189206 return sg , err
190207 }
191208
209+ // Create the stdin socket, a request-reply socket used to request user input from a front-end. This is analogous
210+ // to a standard input stream.
192211 sg .StdinSocket , err = context .NewSocket (zmq .ROUTER )
193212 if err != nil {
194213 return sg , err
195214 }
196215
216+ // Create the iopub socket, a publisher for broadcasting data like stdout/stderr output, displaying execution
217+ // results or errors, kernel status, etc. to connected subscribers.
197218 sg .IOPubSocket , err = context .NewSocket (zmq .PUB )
198219 if err != nil {
199220 return sg , err
200221 }
201222
223+ // Create the heartbeat socket, a request-reply socket that only allows alternating recv-send (request-reply)
224+ // calls. It should echo the byte strings it receives to let the requester know the kernel is still alive.
225+ sg .HBSocket , err = context .NewSocket (zmq .REP )
226+ if err != nil {
227+ return sg , err
228+ }
229+
202230 // Bind the sockets.
203231 address := fmt .Sprintf ("%v://%v:%%v" , connInfo .Transport , connInfo .IP )
204232 sg .ShellSocket .Bind (fmt .Sprintf (address , connInfo .ShellPort ))
205233 sg .ControlSocket .Bind (fmt .Sprintf (address , connInfo .ControlPort ))
206234 sg .StdinSocket .Bind (fmt .Sprintf (address , connInfo .StdinPort ))
207235 sg .IOPubSocket .Bind (fmt .Sprintf (address , connInfo .IOPubPort ))
236+ sg .HBSocket .Bind (fmt .Sprintf (address , connInfo .HBPort ))
208237
209238 // Set the message signing key.
210239 sg .Key = []byte (connInfo .Key )
@@ -254,10 +283,10 @@ func sendKernelInfo(receipt msgReceipt) error {
254283// handleExecuteRequest runs code from an execute_request method,
255284// and sends the various reply messages.
256285func handleExecuteRequest (ir * classic.Interp , receipt msgReceipt ) error {
257- // Extract the data from the request
286+
287+ // Extract the data from the request.
258288 reqcontent := receipt .Msg .Content .(map [string ]interface {})
259289 code := reqcontent ["code" ].(string )
260- in := bufio .NewReader (strings .NewReader (code ))
261290 silent := reqcontent ["silent" ].(bool )
262291
263292 if ! silent {
@@ -293,78 +322,60 @@ func handleExecuteRequest(ir *classic.Interp, receipt msgReceipt) error {
293322 os .Stdout = wOut
294323
295324 // Redirect the standard error from the REPL.
325+ oldStderr := os .Stderr
296326 rErr , wErr , err := os .Pipe ()
297327 if err != nil {
298328 return err
299329 }
300- ir .Stderr = wErr
330+ os .Stderr = wErr
301331
302- // Prepare and perform the multiline evaluation.
303- env := ir .Env
304- env .Options &^= base .OptShowPrompt
305- env .Line = 0
332+ var writersWG sync.WaitGroup
333+ writersWG .Add (2 )
306334
307- // Perform the first iteration manually, to collect comments.
308- var comments string
309- str , firstToken := env .ReadMultiline (in , base .ReadOptCollectAllComments )
310- if firstToken >= 0 {
311- comments = str [0 :firstToken ]
312- if firstToken > 0 {
313- str = str [firstToken :]
314- env .IncLine (comments )
315- }
316- }
317- if ir .ParseEvalPrint (str , in ) {
318- ir .Repl (in )
319- }
335+ // Forward all data written to stdout/stderr to the front-end.
336+ go func () {
337+ defer writersWG .Done ()
338+ jupyterStdOut := JupyterStreamWriter {StreamStdout , & receipt }
339+ io .Copy (& jupyterStdOut , rOut )
340+ }()
320341
321- // Copy the stdout in a separate goroutine to prevent
322- // blocking on printing.
323- outStdout := make (chan string )
324342 go func () {
325- var buf bytes. Buffer
326- io . Copy ( & buf , rOut )
327- outStdout <- buf . String ( )
343+ defer writersWG . Done ()
344+ jupyterStdErr := JupyterStreamWriter { StreamStderr , & receipt }
345+ io . Copy ( & jupyterStdErr , rErr )
328346 }()
329347
330- // Return stdout back to normal state.
348+ vals , executionErr := doEval (ir , code )
349+
350+ //TODO if value is a certain type like image then display it instead
351+
352+ // Close and restore the streams.
331353 wOut .Close ()
332354 os .Stdout = oldStdout
333- val := <- outStdout
334-
335- // Copy the stderr in a separate goroutine to prevent
336- // blocking on printing.
337- outStderr := make (chan string )
338- go func () {
339- var buf bytes.Buffer
340- io .Copy (& buf , rErr )
341- outStderr <- buf .String ()
342- }()
343355
344356 wErr .Close ()
345- stdErr := <- outStderr
357+ os . Stderr = oldStderr
346358
347- // TODO write stdout and stderr to streams rather than publishing as results
359+ // Wait for the writers to finish forwarding the data.
360+ writersWG .Wait ()
348361
349- if len ( val ) > 0 {
362+ if executionErr == nil {
350363 content ["status" ] = "ok"
351364 content ["user_expressions" ] = make (map [string ]string )
352365
353- if ! silent {
366+ if ! silent && vals != nil {
354367 // Publish the result of the execution.
355- if err := receipt .PublishExecutionResult (ExecCounter , val ); err != nil {
368+ if err := receipt .PublishExecutionResult (ExecCounter , fmt . Sprint ( vals ... ) ); err != nil {
356369 log .Printf ("Error publishing execution result: %v\n " , err )
357370 }
358371 }
359- }
360-
361- if len (stdErr ) > 0 {
372+ } else {
362373 content ["status" ] = "error"
363374 content ["ename" ] = "ERROR"
364- content ["evalue" ] = stdErr
375+ content ["evalue" ] = executionErr . Error ()
365376 content ["traceback" ] = nil
366377
367- if err := receipt .PublishExecutionError (stdErr , []string {stdErr }); err != nil {
378+ if err := receipt .PublishExecutionError (executionErr . Error () , []string {executionErr . Error () }); err != nil {
368379 log .Printf ("Error publishing execution error: %v\n " , err )
369380 }
370381 }
@@ -373,6 +384,86 @@ func handleExecuteRequest(ir *classic.Interp, receipt msgReceipt) error {
373384 return receipt .Reply ("execute_reply" , content )
374385}
375386
387+ // doEval evaluates the code in the interpreter. This function captures an uncaught panic
388+ // as well as the values of the last statement/expression.
389+ func doEval (ir * classic.Interp , code string ) (_ []interface {}, err error ) {
390+
391+ // Capture a panic from the evaluation if one occurs and store it in the `err` return parameter.
392+ defer func () {
393+ if r := recover (); r != nil {
394+ var ok bool
395+ if err , ok = r .(error ); ! ok {
396+ err = errors .New (fmt .Sprint (r ))
397+ }
398+ }
399+ }()
400+
401+ // Prepare and perform the multiline evaluation.
402+ env := ir .Env
403+
404+ // Don't show the gomacro prompt.
405+ env .Options &^= base .OptShowPrompt
406+
407+ // Don't swallow panics as they are recovered above and handled with a Jupyter `error` message instead.
408+ env .Options &^= base .OptTrapPanic
409+
410+ // Reset the error line so that error messages correspond to the lines from the cell.
411+ env .Line = 0
412+
413+ // Parse the input code (and don't preform gomacro's macroexpansion).
414+ src := ir .ParseOnly (code )
415+
416+ if src == nil {
417+ return nil , nil
418+ }
419+
420+ // Check if the last node is an expression.
421+ var srcEndsWithExpr bool
422+
423+ // If the parsed ast is a single node, check if the node implements `ast.Expr`. Otherwise if the is multiple
424+ // nodes then just check if the last one is an expression. These are currently the 2 cases to consider from
425+ // gomacro's `ParseOnly`.
426+ if srcAstWithNode , ok := src .(ast2.AstWithNode ); ok {
427+ _ , srcEndsWithExpr = srcAstWithNode .Node ().(ast.Expr )
428+ } else if srcNodeSlice , ok := src .(ast2.NodeSlice ); ok {
429+ nodes := srcNodeSlice .X
430+ _ , srcEndsWithExpr = nodes [len (nodes )- 1 ].(ast.Expr )
431+ }
432+
433+ // Evaluate the code.
434+ result , results := ir .EvalAst (src )
435+
436+ // If the source ends with an expression, then the result of the execution is the value of the expression. In the
437+ // event that all return values are nil, the result is also nil.
438+ if srcEndsWithExpr {
439+ // `len(results) == 0` implies a single result stored in `result`.
440+ if len (results ) == 0 {
441+ if val := base .ValueInterface (result ); val != nil {
442+ return []interface {}{val }, nil
443+ }
444+ return nil , nil
445+ }
446+
447+ // Count the number of non-nil values in the output. If they are all nil then the output is skipped.
448+ nonNilCount := 0
449+ var values []interface {}
450+ for _ , result := range results {
451+ val := base .ValueInterface (result )
452+ if val != nil {
453+ nonNilCount ++
454+ }
455+ values = append (values , val )
456+ }
457+
458+ if nonNilCount > 0 {
459+ return values , nil
460+ }
461+ return nil , nil
462+ }
463+
464+ return nil , nil
465+ }
466+
376467// handleShutdownRequest sends a "shutdown" message.
377468func handleShutdownRequest (receipt msgReceipt ) {
378469 content := receipt .Msg .Content .(map [string ]interface {})
@@ -389,3 +480,49 @@ func handleShutdownRequest(receipt msgReceipt) {
389480 log .Println ("Shutting down in response to shutdown_request" )
390481 os .Exit (0 )
391482}
483+
484+ // startHeartbeat starts a go-routine for handling heartbeat ping messages sent over the given `hbSocket`. The `wg`'s
485+ // `Done` method is invoked after the thread is completely shutdown. To request a shutdown the returned `shutdown` channel
486+ // can be closed.
487+ func startHeartbeat (hbSocket * zmq.Socket , wg * sync.WaitGroup ) (shutdown chan struct {}) {
488+ quit := make (chan struct {})
489+
490+ // Start the handler that will echo any received messages back to the sender.
491+ wg .Add (1 )
492+ go func () {
493+ defer wg .Done ()
494+
495+ // Create a `Poller` to check for incoming messages.
496+ poller := zmq .NewPoller ()
497+ poller .Add (hbSocket , zmq .POLLIN )
498+
499+ for {
500+ select {
501+ case <- quit :
502+ return
503+ default :
504+ // Check for received messages waiting at most 500ms for once to arrive.
505+ pingEvents , err := poller .Poll (500 * time .Millisecond )
506+ if err != nil {
507+ log .Fatalf ("Error polling heartbeat channel: %v\n " , err )
508+ }
509+
510+ // If there is at least 1 message waiting then echo it.
511+ if len (pingEvents ) > 0 {
512+ // Read a message from the heartbeat channel as a simple byte string.
513+ pingMsg , err := hbSocket .RecvBytes (0 )
514+ if err != nil {
515+ log .Fatalf ("Error reading heartbeat ping bytes: %v\n " , err )
516+ }
517+
518+ // Send the received byte string back to let the front-end know that the kernel is alive.
519+ if _ , err = hbSocket .SendBytes (pingMsg , 0 ); err != nil {
520+ log .Printf ("Error sending heartbeat pong bytes: %b\n " , err )
521+ }
522+ }
523+ }
524+ }
525+ }()
526+
527+ return quit
528+ }
0 commit comments