Skip to content

Commit 35950ad

Browse files
committed
Fix more race conditions
1 parent 8979076 commit 35950ad

File tree

6 files changed

+40
-31
lines changed

6 files changed

+40
-31
lines changed

README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# wsep
22

3-
`wsep` is a high performance, ***W***eb***S***ocket command ***e***xecution ***p***rotocol. It can be thought of as SSH without encryption.
3+
`wsep` is a high performance,
4+
<strong style="font-size: 1.5em; text-decoration: underline;">w</strong>eb <strong style="font-size: 1.5em;text-decoration: underline;">s</strong>ocket command <strong style="font-size: 1.5em;text-decoration: underline;">e</strong>xecution <strong style="font-size: 1.5em;text-decoration: underline;">p</strong>rotocol. It can be thought of as SSH without encryption.
45

56
It's useful in cases where you want to provide a command exec interface into a remote environment. It's implemented
67
with WebSocket so it may be used directly by a browser frontend.
@@ -25,9 +26,7 @@ if err != nil {
2526
// handle error
2627
}
2728

28-
go io.Copy(os.Stdout, process.Stdout())
29-
go io.Copy(os.Stderr, process.Stderr())
30-
go io.Copy(process.Stdin(), os.Stdin)
29+
io.Copy(os.Stderr, process.Stderr())
3130

3231
err = process.Wait()
3332
if err != nil {

client_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ func TestRemoteExec(t *testing.T) {
7676
defer cancel()
7777

7878
ws, server := mockConn(ctx, t)
79-
defer ws.Close(websocket.StatusAbnormalClosure, "abnormal closure")
8079
defer server.Close()
8180

8281
execer := RemoteExecer(ws)

exec.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,14 @@ func (e ExitError) Error() string {
1919

2020
// Process represents a started command.
2121
type Process interface {
22+
// Pid is populated immediately during a successfull start with the process ID.
2223
Pid() int
24+
// Stdout returns an io.WriteCloser that will pipe writes to the remote command.
25+
// Closure of stdin sends the correspoding close messsage.
2326
Stdin() io.WriteCloser
27+
// Stdout returns an io.Reader that is connected to the command's standard output.
2428
Stdout() io.Reader
29+
// Stderr returns an io.Reader that is connected to the command's standard error.
2530
Stderr() io.Reader
2631
// Resize resizes the TTY if a TTY is enabled.
2732
Resize(ctx context.Context, rows, cols uint16) error

internal/proto/protocol.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,5 @@ func (h headerWriter) Write(b []byte) (int, error) {
4646
if err != nil {
4747
return 0, err
4848
}
49-
return len(b), nil // TODO: potential buggy
49+
return len(b), nil
5050
}

localexec_test.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,33 @@ func testExecer(ctx context.Context, t *testing.T, execer Execer) {
2424
Command: "pwd",
2525
})
2626
assert.Success(t, "start local cmd", err)
27+
var (
28+
stderr = process.Stderr()
29+
stdout = process.Stdout()
30+
wg sync.WaitGroup
31+
)
2732

28-
var wg sync.WaitGroup
2933
wg.Add(1)
3034
go func() {
3135
defer wg.Done()
3236

33-
stdout, err := ioutil.ReadAll(process.Stdout())
37+
stdoutByt, err := ioutil.ReadAll(stdout)
3438
assert.Success(t, "read stdout", err)
3539
wd, err := os.Getwd()
3640
assert.Success(t, "get real working dir", err)
3741

38-
assert.Equal(t, "stdout", wd, strings.TrimSuffix(string(stdout), "\n"))
42+
assert.Equal(t, "stdout", wd, strings.TrimSuffix(string(stdoutByt), "\n"))
3943
}()
4044
wg.Add(1)
4145
go func() {
4246
defer wg.Done()
4347

44-
stderr, err := ioutil.ReadAll(process.Stderr())
48+
stderrByt, err := ioutil.ReadAll(stderr)
4549
assert.Success(t, "read stderr", err)
46-
assert.True(t, "len stderr", len(stderr) == 0)
50+
assert.True(t, "len stderr", len(stderrByt) == 0)
4751
}()
4852

53+
wg.Wait()
4954
err = process.Wait()
5055
assert.Success(t, "wait for process to complete", err)
51-
52-
wg.Wait()
5356
}

server.go

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ import (
77
"errors"
88
"io"
99
"net"
10-
"sync"
1110

1211
"go.coder.com/flog"
12+
"golang.org/x/sync/errgroup"
1313
"golang.org/x/xerrors"
1414
"nhooyr.io/websocket"
1515

@@ -68,13 +68,23 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer) error {
6868
}
6969

7070
sendPID(ctx, process.Pid(), wsNetConn)
71-
var wg sync.WaitGroup
72-
go pipeProcessOutput(ctx, process, wsNetConn, &wg)
71+
72+
var outputgroup errgroup.Group
73+
outputgroup.Go(func() error {
74+
return copyWithHeader(process.Stdout(), wsNetConn, proto.Header{Type: proto.TypeStdout})
75+
})
76+
outputgroup.Go(func() error {
77+
return copyWithHeader(process.Stderr(), wsNetConn, proto.Header{Type: proto.TypeStdout})
78+
})
7379

7480
go func() {
7581
defer wsNetConn.Close()
82+
err := outputgroup.Wait()
83+
if err != nil {
84+
// connection should close without an exit code if copy fails
85+
return
86+
}
7687
err = process.Wait()
77-
wg.Wait()
7888
if exitErr, ok := err.(*ExitError); ok {
7989
sendExitCode(ctx, exitErr.Code, wsNetConn)
8090
return
@@ -131,22 +141,15 @@ func sendPID(ctx context.Context, pid int, conn net.Conn) {
131141
proto.WithHeader(conn, header).Write(nil)
132142
}
133143

134-
func pipeProcessOutput(ctx context.Context, process Process, conn net.Conn, wg *sync.WaitGroup) {
135-
var (
136-
stdout = process.Stdout()
137-
stderr = process.Stderr()
138-
)
139-
wg.Add(2)
140-
go copyWithHeader(stdout, conn, proto.Header{Type: proto.TypeStdout}, wg)
141-
go copyWithHeader(stderr, conn, proto.Header{Type: proto.TypeStderr}, wg)
142-
}
143-
144-
func copyWithHeader(r io.Reader, w io.Writer, header proto.Header, wg *sync.WaitGroup) {
145-
defer wg.Done()
144+
func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) error {
146145
headerByt, err := json.Marshal(header)
147146
if err != nil {
148-
return
147+
return err
149148
}
150149
wr := proto.WithHeader(w, headerByt)
151-
io.Copy(wr, r)
150+
_, err = io.Copy(wr, r)
151+
if err != nil {
152+
return err
153+
}
154+
return nil
152155
}

0 commit comments

Comments
 (0)