Skip to content

Commit b0aca77

Browse files
authored
Merge pull request #79 from SpencerPark/feature/live-output
Live stream output and improved execution semantics
2 parents 5a3b5d7 + 1d6b207 commit b0aca77

File tree

3 files changed

+476
-115
lines changed

3 files changed

+476
-115
lines changed

kernel.go

Lines changed: 189 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
package main
22

33
import (
4-
"bufio"
5-
"bytes"
64
"encoding/json"
5+
"errors"
76
"fmt"
7+
"go/ast"
88
"io"
99
"io/ioutil"
1010
"log"
1111
"os"
1212
"runtime"
13-
"strings"
13+
"sync"
14+
"time"
1415

16+
"github.com/cosmos72/gomacro/ast2"
1517
"github.com/cosmos72/gomacro/base"
1618
"github.com/cosmos72/gomacro/classic"
1719
zmq "github.com/pebbe/zmq4"
@@ -41,6 +43,7 @@ type SocketGroup struct {
4143
ControlSocket *zmq.Socket
4244
StdinSocket *zmq.Socket
4345
IOPubSocket *zmq.Socket
46+
HBSocket *zmq.Socket
4447
Key []byte
4548
}
4649

@@ -71,7 +74,7 @@ type kernelInfo struct {
7174
HelpLinks []helpLink `json:"help_links"`
7275
}
7376

74-
// shutdownReply encodes a boolean indication of stutdown/restart.
77+
// shutdownReply encodes a boolean indication of shutdown/restart.
7578
type shutdownReply struct {
7679
Restart bool `json:"restart"`
7780
}
@@ -88,6 +91,10 @@ func runKernel(connectionFile string) {
8891
// Set up the "Session" with the replpkg.
8992
ir := classic.New()
9093

94+
// Throw out the error/warning messages that gomacro outputs writes to these streams.
95+
ir.Stdout = ioutil.Discard
96+
ir.Stderr = ioutil.Discard
97+
9198
// Parse the connection info.
9299
var connInfo ConnectionInfo
93100

@@ -106,6 +113,13 @@ func runKernel(connectionFile string) {
106113
log.Fatal(err)
107114
}
108115

116+
// TODO connect all channel handlers to a WaitGroup to ensure shutdown before returning from runKernel.
117+
118+
// Start up the heartbeat handler.
119+
startHeartbeat(sockets.HBSocket, &sync.WaitGroup{})
120+
121+
// TODO gracefully shutdown the heartbeat handler on kernel shutdown by closing the chan returned by startHeartbeat.
122+
109123
poller := zmq.NewPoller()
110124
poller.Add(sockets.ShellSocket, zmq.POLLIN)
111125
poller.Add(sockets.StdinSocket, zmq.POLLIN)
@@ -116,7 +130,6 @@ func runKernel(connectionFile string) {
116130

117131
// Start a message receiving loop.
118132
for {
119-
120133
polled, err := poller.Poll(-1)
121134
if err != nil {
122135
log.Fatal(err)
@@ -179,32 +192,48 @@ func prepareSockets(connInfo ConnectionInfo) (SocketGroup, error) {
179192
// Initialize the socket group.
180193
var sg SocketGroup
181194

195+
// Create the shell socket, a request-reply socket that may receive messages from multiple frontend for
196+
// code execution, introspection, auto-completion, etc.
182197
sg.ShellSocket, err = context.NewSocket(zmq.ROUTER)
183198
if err != nil {
184199
return sg, err
185200
}
186201

202+
// Create the control socket. This socket is a duplicate of the shell socket where messages on this channel
203+
// should jump ahead of queued messages on the shell socket.
187204
sg.ControlSocket, err = context.NewSocket(zmq.ROUTER)
188205
if err != nil {
189206
return sg, err
190207
}
191208

209+
// Create the stdin socket, a request-reply socket used to request user input from a front-end. This is analogous
210+
// to a standard input stream.
192211
sg.StdinSocket, err = context.NewSocket(zmq.ROUTER)
193212
if err != nil {
194213
return sg, err
195214
}
196215

216+
// Create the iopub socket, a publisher for broadcasting data like stdout/stderr output, displaying execution
217+
// results or errors, kernel status, etc. to connected subscribers.
197218
sg.IOPubSocket, err = context.NewSocket(zmq.PUB)
198219
if err != nil {
199220
return sg, err
200221
}
201222

223+
// Create the heartbeat socket, a request-reply socket that only allows alternating recv-send (request-reply)
224+
// calls. It should echo the byte strings it receives to let the requester know the kernel is still alive.
225+
sg.HBSocket, err = context.NewSocket(zmq.REP)
226+
if err != nil {
227+
return sg, err
228+
}
229+
202230
// Bind the sockets.
203231
address := fmt.Sprintf("%v://%v:%%v", connInfo.Transport, connInfo.IP)
204232
sg.ShellSocket.Bind(fmt.Sprintf(address, connInfo.ShellPort))
205233
sg.ControlSocket.Bind(fmt.Sprintf(address, connInfo.ControlPort))
206234
sg.StdinSocket.Bind(fmt.Sprintf(address, connInfo.StdinPort))
207235
sg.IOPubSocket.Bind(fmt.Sprintf(address, connInfo.IOPubPort))
236+
sg.HBSocket.Bind(fmt.Sprintf(address, connInfo.HBPort))
208237

209238
// Set the message signing key.
210239
sg.Key = []byte(connInfo.Key)
@@ -254,10 +283,10 @@ func sendKernelInfo(receipt msgReceipt) error {
254283
// handleExecuteRequest runs code from an execute_request method,
255284
// and sends the various reply messages.
256285
func handleExecuteRequest(ir *classic.Interp, receipt msgReceipt) error {
257-
// Extract the data from the request
286+
287+
// Extract the data from the request.
258288
reqcontent := receipt.Msg.Content.(map[string]interface{})
259289
code := reqcontent["code"].(string)
260-
in := bufio.NewReader(strings.NewReader(code))
261290
silent := reqcontent["silent"].(bool)
262291

263292
if !silent {
@@ -293,78 +322,60 @@ func handleExecuteRequest(ir *classic.Interp, receipt msgReceipt) error {
293322
os.Stdout = wOut
294323

295324
// Redirect the standard error from the REPL.
325+
oldStderr := os.Stderr
296326
rErr, wErr, err := os.Pipe()
297327
if err != nil {
298328
return err
299329
}
300-
ir.Stderr = wErr
330+
os.Stderr = wErr
301331

302-
// Prepare and perform the multiline evaluation.
303-
env := ir.Env
304-
env.Options &^= base.OptShowPrompt
305-
env.Line = 0
332+
var writersWG sync.WaitGroup
333+
writersWG.Add(2)
306334

307-
// Perform the first iteration manually, to collect comments.
308-
var comments string
309-
str, firstToken := env.ReadMultiline(in, base.ReadOptCollectAllComments)
310-
if firstToken >= 0 {
311-
comments = str[0:firstToken]
312-
if firstToken > 0 {
313-
str = str[firstToken:]
314-
env.IncLine(comments)
315-
}
316-
}
317-
if ir.ParseEvalPrint(str, in) {
318-
ir.Repl(in)
319-
}
335+
// Forward all data written to stdout/stderr to the front-end.
336+
go func() {
337+
defer writersWG.Done()
338+
jupyterStdOut := JupyterStreamWriter{StreamStdout, &receipt}
339+
io.Copy(&jupyterStdOut, rOut)
340+
}()
320341

321-
// Copy the stdout in a separate goroutine to prevent
322-
// blocking on printing.
323-
outStdout := make(chan string)
324342
go func() {
325-
var buf bytes.Buffer
326-
io.Copy(&buf, rOut)
327-
outStdout <- buf.String()
343+
defer writersWG.Done()
344+
jupyterStdErr := JupyterStreamWriter{StreamStderr, &receipt}
345+
io.Copy(&jupyterStdErr, rErr)
328346
}()
329347

330-
// Return stdout back to normal state.
348+
vals, executionErr := doEval(ir, code)
349+
350+
//TODO if value is a certain type like image then display it instead
351+
352+
// Close and restore the streams.
331353
wOut.Close()
332354
os.Stdout = oldStdout
333-
val := <-outStdout
334-
335-
// Copy the stderr in a separate goroutine to prevent
336-
// blocking on printing.
337-
outStderr := make(chan string)
338-
go func() {
339-
var buf bytes.Buffer
340-
io.Copy(&buf, rErr)
341-
outStderr <- buf.String()
342-
}()
343355

344356
wErr.Close()
345-
stdErr := <-outStderr
357+
os.Stderr = oldStderr
346358

347-
// TODO write stdout and stderr to streams rather than publishing as results
359+
// Wait for the writers to finish forwarding the data.
360+
writersWG.Wait()
348361

349-
if len(val) > 0 {
362+
if executionErr == nil {
350363
content["status"] = "ok"
351364
content["user_expressions"] = make(map[string]string)
352365

353-
if !silent {
366+
if !silent && vals != nil {
354367
// Publish the result of the execution.
355-
if err := receipt.PublishExecutionResult(ExecCounter, val); err != nil {
368+
if err := receipt.PublishExecutionResult(ExecCounter, fmt.Sprint(vals...)); err != nil {
356369
log.Printf("Error publishing execution result: %v\n", err)
357370
}
358371
}
359-
}
360-
361-
if len(stdErr) > 0 {
372+
} else {
362373
content["status"] = "error"
363374
content["ename"] = "ERROR"
364-
content["evalue"] = stdErr
375+
content["evalue"] = executionErr.Error()
365376
content["traceback"] = nil
366377

367-
if err := receipt.PublishExecutionError(stdErr, []string{stdErr}); err != nil {
378+
if err := receipt.PublishExecutionError(executionErr.Error(), []string{executionErr.Error()}); err != nil {
368379
log.Printf("Error publishing execution error: %v\n", err)
369380
}
370381
}
@@ -373,6 +384,86 @@ func handleExecuteRequest(ir *classic.Interp, receipt msgReceipt) error {
373384
return receipt.Reply("execute_reply", content)
374385
}
375386

387+
// doEval evaluates the code in the interpreter. This function captures an uncaught panic
388+
// as well as the values of the last statement/expression.
389+
func doEval(ir *classic.Interp, code string) (_ []interface{}, err error) {
390+
391+
// Capture a panic from the evaluation if one occurs and store it in the `err` return parameter.
392+
defer func() {
393+
if r := recover(); r != nil {
394+
var ok bool
395+
if err, ok = r.(error); !ok {
396+
err = errors.New(fmt.Sprint(r))
397+
}
398+
}
399+
}()
400+
401+
// Prepare and perform the multiline evaluation.
402+
env := ir.Env
403+
404+
// Don't show the gomacro prompt.
405+
env.Options &^= base.OptShowPrompt
406+
407+
// Don't swallow panics as they are recovered above and handled with a Jupyter `error` message instead.
408+
env.Options &^= base.OptTrapPanic
409+
410+
// Reset the error line so that error messages correspond to the lines from the cell.
411+
env.Line = 0
412+
413+
// Parse the input code (and don't preform gomacro's macroexpansion).
414+
src := ir.ParseOnly(code)
415+
416+
if src == nil {
417+
return nil, nil
418+
}
419+
420+
// Check if the last node is an expression.
421+
var srcEndsWithExpr bool
422+
423+
// If the parsed ast is a single node, check if the node implements `ast.Expr`. Otherwise if the is multiple
424+
// nodes then just check if the last one is an expression. These are currently the 2 cases to consider from
425+
// gomacro's `ParseOnly`.
426+
if srcAstWithNode, ok := src.(ast2.AstWithNode); ok {
427+
_, srcEndsWithExpr = srcAstWithNode.Node().(ast.Expr)
428+
} else if srcNodeSlice, ok := src.(ast2.NodeSlice); ok {
429+
nodes := srcNodeSlice.X
430+
_, srcEndsWithExpr = nodes[len(nodes)-1].(ast.Expr)
431+
}
432+
433+
// Evaluate the code.
434+
result, results := ir.EvalAst(src)
435+
436+
// If the source ends with an expression, then the result of the execution is the value of the expression. In the
437+
// event that all return values are nil, the result is also nil.
438+
if srcEndsWithExpr {
439+
// `len(results) == 0` implies a single result stored in `result`.
440+
if len(results) == 0 {
441+
if val := base.ValueInterface(result); val != nil {
442+
return []interface{}{val}, nil
443+
}
444+
return nil, nil
445+
}
446+
447+
// Count the number of non-nil values in the output. If they are all nil then the output is skipped.
448+
nonNilCount := 0
449+
var values []interface{}
450+
for _, result := range results {
451+
val := base.ValueInterface(result)
452+
if val != nil {
453+
nonNilCount++
454+
}
455+
values = append(values, val)
456+
}
457+
458+
if nonNilCount > 0 {
459+
return values, nil
460+
}
461+
return nil, nil
462+
}
463+
464+
return nil, nil
465+
}
466+
376467
// handleShutdownRequest sends a "shutdown" message.
377468
func handleShutdownRequest(receipt msgReceipt) {
378469
content := receipt.Msg.Content.(map[string]interface{})
@@ -389,3 +480,49 @@ func handleShutdownRequest(receipt msgReceipt) {
389480
log.Println("Shutting down in response to shutdown_request")
390481
os.Exit(0)
391482
}
483+
484+
// startHeartbeat starts a go-routine for handling heartbeat ping messages sent over the given `hbSocket`. The `wg`'s
485+
// `Done` method is invoked after the thread is completely shutdown. To request a shutdown the returned `shutdown` channel
486+
// can be closed.
487+
func startHeartbeat(hbSocket *zmq.Socket, wg *sync.WaitGroup) (shutdown chan struct{}) {
488+
quit := make(chan struct{})
489+
490+
// Start the handler that will echo any received messages back to the sender.
491+
wg.Add(1)
492+
go func() {
493+
defer wg.Done()
494+
495+
// Create a `Poller` to check for incoming messages.
496+
poller := zmq.NewPoller()
497+
poller.Add(hbSocket, zmq.POLLIN)
498+
499+
for {
500+
select {
501+
case <-quit:
502+
return
503+
default:
504+
// Check for received messages waiting at most 500ms for once to arrive.
505+
pingEvents, err := poller.Poll(500 * time.Millisecond)
506+
if err != nil {
507+
log.Fatalf("Error polling heartbeat channel: %v\n", err)
508+
}
509+
510+
// If there is at least 1 message waiting then echo it.
511+
if len(pingEvents) > 0 {
512+
// Read a message from the heartbeat channel as a simple byte string.
513+
pingMsg, err := hbSocket.RecvBytes(0)
514+
if err != nil {
515+
log.Fatalf("Error reading heartbeat ping bytes: %v\n", err)
516+
}
517+
518+
// Send the received byte string back to let the front-end know that the kernel is alive.
519+
if _, err = hbSocket.SendBytes(pingMsg, 0); err != nil {
520+
log.Printf("Error sending heartbeat pong bytes: %b\n", err)
521+
}
522+
}
523+
}
524+
}
525+
}()
526+
527+
return quit
528+
}

0 commit comments

Comments
 (0)