Skip to content

Commit 8979076

Browse files
committed
Fix race conditions with stderr/stdout streams
1 parent 3e3e419 commit 8979076

File tree

4 files changed

+65
-36
lines changed

4 files changed

+65
-36
lines changed

client.go

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
package wsep
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"io"
78
"net"
89

910
"cdr.dev/wsep/internal/proto"
10-
"go.coder.com/flog"
11+
"golang.org/x/sync/errgroup"
1112
"golang.org/x/xerrors"
1213
"nhooyr.io/websocket"
1314
)
@@ -123,46 +124,67 @@ func newPipe() pipe {
123124

124125
func (r remoteProcess) listen(ctx context.Context) {
125126
defer r.conn.Close(websocket.StatusNormalClosure, "normal closure")
126-
defer r.stdout.w.Close()
127-
defer r.stderr.w.Close()
128127

129-
for {
130-
if err := ctx.Err(); err != nil {
131-
r.done <- xerrors.Errorf("process canceled: %w", err)
132-
break
133-
}
134-
_, payload, err := r.conn.Read(ctx)
135-
if err != nil {
136-
continue
137-
}
138-
headerByt, body := proto.SplitMessage(payload)
128+
exitCode := make(chan int, 1)
129+
var eg errgroup.Group
139130

140-
var header proto.Header
141-
err = json.Unmarshal(headerByt, &header)
142-
if err != nil {
143-
continue
144-
}
131+
eg.Go(func() error {
132+
defer r.stdout.w.Close()
133+
defer r.stderr.w.Close()
145134

146-
switch header.Type {
147-
case proto.TypeStderr:
148-
go r.stderr.w.Write(body)
149-
case proto.TypeStdout:
150-
go r.stdout.w.Write(body)
151-
case proto.TypeExitCode:
152-
var exitMsg proto.ServerExitCodeHeader
153-
err = json.Unmarshal(headerByt, &exitMsg)
135+
buf := make([]byte, 32<<10) // max size of one websocket message
136+
for {
137+
if err := ctx.Err(); err != nil {
138+
r.done <- xerrors.Errorf("process canceled: %w", err)
139+
break
140+
}
141+
_, payload, err := r.conn.Read(ctx)
154142
if err != nil {
155-
flog.Error("failed to unmarshal exit code message: %v", err)
156143
continue
157144
}
145+
headerByt, body := proto.SplitMessage(payload)
158146

159-
var err error = ExitError{Code: exitMsg.ExitCode}
160-
if exitMsg.ExitCode == 0 {
161-
err = nil
147+
var header proto.Header
148+
err = json.Unmarshal(headerByt, &header)
149+
if err != nil {
150+
continue
151+
}
152+
153+
switch header.Type {
154+
case proto.TypeStderr:
155+
_, err = io.CopyBuffer(r.stderr.w, bytes.NewReader(body), buf)
156+
if err != nil {
157+
return err
158+
}
159+
case proto.TypeStdout:
160+
_, err = io.CopyBuffer(r.stdout.w, bytes.NewReader(body), buf)
161+
if err != nil {
162+
return err
163+
}
164+
case proto.TypeExitCode:
165+
var exitMsg proto.ServerExitCodeHeader
166+
err = json.Unmarshal(headerByt, &exitMsg)
167+
if err != nil {
168+
continue
169+
}
170+
171+
exitCode <- exitMsg.ExitCode
172+
return nil
162173
}
163-
r.done <- err
174+
}
175+
return nil
176+
})
177+
178+
err := eg.Wait()
179+
select {
180+
case exitCode := <-exitCode:
181+
if exitCode != 0 {
182+
r.done <- ExitError{Code: int(exitCode)}
164183
return
165184
}
185+
r.done <- nil
186+
default:
187+
r.done <- err
166188
}
167189
}
168190

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ require (
1010
go.coder.com/cli v0.4.0
1111
go.coder.com/flog v0.0.0-20190906214207-47dd47ea0512
1212
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413
13+
golang.org/x/sync v0.0.0-20190423024810-112230192c58
1314
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543
1415
nhooyr.io/websocket v1.8.6
1516
)

go.sum

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ
206206
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
207207
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
208208
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
209+
golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=
209210
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
210211
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
211212
golang.org/x/sys v0.0.0-20181128092732-4ed8d59d0b35/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

server.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"io"
99
"net"
10+
"sync"
1011

1112
"go.coder.com/flog"
1213
"golang.org/x/xerrors"
@@ -67,11 +68,13 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer) error {
6768
}
6869

6970
sendPID(ctx, process.Pid(), wsNetConn)
70-
go pipeProcessOutput(ctx, process, wsNetConn)
71+
var wg sync.WaitGroup
72+
go pipeProcessOutput(ctx, process, wsNetConn, &wg)
7173

7274
go func() {
7375
defer wsNetConn.Close()
7476
err = process.Wait()
77+
wg.Wait()
7578
if exitErr, ok := err.(*ExitError); ok {
7679
sendExitCode(ctx, exitErr.Code, wsNetConn)
7780
return
@@ -128,16 +131,18 @@ func sendPID(ctx context.Context, pid int, conn net.Conn) {
128131
proto.WithHeader(conn, header).Write(nil)
129132
}
130133

131-
func pipeProcessOutput(ctx context.Context, process Process, conn net.Conn) {
134+
func pipeProcessOutput(ctx context.Context, process Process, conn net.Conn, wg *sync.WaitGroup) {
132135
var (
133136
stdout = process.Stdout()
134137
stderr = process.Stderr()
135138
)
136-
go copyWithHeader(stdout, conn, proto.Header{Type: proto.TypeStdout})
137-
go copyWithHeader(stderr, conn, proto.Header{Type: proto.TypeStderr})
139+
wg.Add(2)
140+
go copyWithHeader(stdout, conn, proto.Header{Type: proto.TypeStdout}, wg)
141+
go copyWithHeader(stderr, conn, proto.Header{Type: proto.TypeStderr}, wg)
138142
}
139143

140-
func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) {
144+
func copyWithHeader(r io.Reader, w io.Writer, header proto.Header, wg *sync.WaitGroup) {
145+
defer wg.Done()
141146
headerByt, err := json.Marshal(header)
142147
if err != nil {
143148
return

0 commit comments

Comments
 (0)