Skip to content
This repository was archived by the owner on Jan 21, 2020. It is now read-only.

Commit c88ca75

Browse files
author
David Chung
authored
fixing exec race conditions (#467)
Signed-off-by: David Chung <[email protected]>
1 parent 8be63aa commit c88ca75

File tree

3 files changed

+109
-96
lines changed

3 files changed

+109
-96
lines changed

pkg/cli/local/context.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -278,18 +278,22 @@ func (c *Context) loadBackend() error {
278278
cmd := strings.Join(append([]string{"/bin/sh"}, opts...), " ")
279279
log.Debug("sh", "cmd", cmd)
280280

281-
return exec.Command(cmd).
282-
InheritEnvs(true).StartWithStreams(
283-
284-
exec.Do(exec.SendInput(
285-
func(stdin io.WriteCloser) error {
286-
_, err := stdin.Write([]byte(script))
287-
return err
288-
})).Then(
289-
exec.RedirectStdout(os.Stdout)).Then(
290-
exec.RedirectStderr(os.Stderr),
291-
).Done(),
281+
run := exec.Command(cmd)
282+
run.InheritEnvs(true).StartWithHandlers(
283+
func(stdin io.Writer) error {
284+
_, err := stdin.Write([]byte(script))
285+
return err
286+
},
287+
func(stdout io.Reader) error {
288+
_, err := io.Copy(os.Stdout, stdout)
289+
return err
290+
},
291+
func(stderr io.Reader) error {
292+
_, err := io.Copy(os.Stderr, stderr)
293+
return err
294+
},
292295
)
296+
return run.Wait()
293297
}
294298
return ""
295299
})

pkg/util/exec/exec.go

Lines changed: 83 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"os"
77
"os/exec"
88
"strings"
9+
"sync"
910

1011
logutil "github.com/docker/infrakit/pkg/log"
1112
"github.com/docker/infrakit/pkg/template"
@@ -34,6 +35,28 @@ type Builder struct {
3435
context interface{}
3536
rendered string // rendered command string
3637
cmd *exec.Cmd
38+
stdout io.Writer
39+
stderr io.Writer
40+
stdin io.Reader
41+
wg sync.WaitGroup
42+
}
43+
44+
// WithStdin sets the stdin reader
45+
func (b *Builder) WithStdin(r io.Reader) *Builder {
46+
b.stdin = r
47+
return b
48+
}
49+
50+
// WithStdout sets the stdout writer
51+
func (b *Builder) WithStdout(w io.Writer) *Builder {
52+
b.stdout = w
53+
return b
54+
}
55+
56+
// WithStderr sets the stderr writer
57+
func (b *Builder) WithStderr(w io.Writer) *Builder {
58+
b.stdout = w
59+
return b
3760
}
3861

3962
// WithArg sets the arg key, value pair that can be accessed via the 'arg' function
@@ -80,120 +103,94 @@ func (b *Builder) WithContext(context interface{}) *Builder {
80103
return b
81104
}
82105

83-
// Step is something you do with the processes streams
84-
type Step func(stdin io.WriteCloser, stdout io.ReadCloser, stderr io.ReadCloser) error
85-
86-
// Thenable is a fluent builder for chaining tasks
87-
type Thenable struct {
88-
steps []Step
89-
}
106+
var noop = func() error { return nil }
90107

91-
// Do creates a thenable
92-
func Do(f Step) *Thenable {
93-
return &Thenable{
94-
steps: []Step{f},
95-
}
96-
}
97-
98-
// Then adds another step
99-
func (t *Thenable) Then(then Step) *Thenable {
100-
t.steps = append(t.steps, then)
101-
return t
102-
}
103-
104-
// Done returns the final function
105-
func (t *Thenable) Done() Step {
106-
all := t.steps
107-
return func(stdin io.WriteCloser, stdout, stderr io.ReadCloser) error {
108-
for _, next := range all {
109-
if err := next(stdin, stdout, stderr); err != nil {
110-
return err
111-
}
112-
}
113-
return nil
114-
}
115-
}
116-
117-
// SendInput is a convenience function for writing to the exec process's stdin. When the function completes, the
118-
// stdin is closed.
119-
func SendInput(f func(io.WriteCloser) error) Step {
120-
return func(stdin io.WriteCloser, stdout, stderr io.ReadCloser) error {
121-
defer stdin.Close()
122-
return f(stdin)
123-
}
124-
}
125-
126-
// RedirectStdout sends stdout to given writer
127-
func RedirectStdout(out io.Writer) Step {
128-
return func(stdin io.WriteCloser, stdout, stderr io.ReadCloser) error {
129-
_, err := io.Copy(out, stdout)
130-
return err
131-
}
132-
}
133-
134-
// RedirectStderr sends stdout to given writer
135-
func RedirectStderr(out io.Writer) Step {
136-
return func(stdin io.WriteCloser, stdout, stderr io.ReadCloser) error {
137-
_, err := io.Copy(out, stderr)
138-
return err
139-
}
140-
}
141-
142-
// MergeOutput combines the stdout and stderr into the given stream
143-
func MergeOutput(out io.Writer) Step {
144-
return func(stdin io.WriteCloser, stdout, stderr io.ReadCloser) error {
145-
_, err := io.Copy(out, io.MultiReader(stdout, stderr))
146-
return err
147-
}
148-
}
149-
150-
// StartWithStreams starts the the process and then calls the function which allows
151-
// the streams to be wired. Calling the provided function blocks.
152-
func (b *Builder) StartWithStreams(f Step, args ...interface{}) error {
108+
// StartWithHandlers starts the cmd non blocking and calls the given handlers to process input / output
109+
func (b *Builder) StartWithHandlers(stdinFunc func(io.Writer) error,
110+
stdoutFunc func(io.Reader) error,
111+
stderrFunc func(io.Reader) error,
112+
args ...interface{}) error {
153113

154114
if err := b.Prepare(args...); err != nil {
155115
return err
156116
}
157117

158-
run := func() error { return nil }
159-
if f != nil {
118+
// There's a race between the input/output streams reads and cmd.wait() which
119+
// will close the pipes even while others are trying to read.
120+
// So we need to ensure that all the input/output are done before actually waiting
121+
// on the cmd to complete.
122+
// To do so, we use a waitgroup
123+
124+
handleInput := noop
125+
if stdinFunc != nil {
160126
pIn, err := b.cmd.StdinPipe()
161127
if err != nil {
162128
return err
163129
}
164130

131+
handleInput = func() error {
132+
defer func() {
133+
pIn.Close()
134+
b.wg.Done()
135+
}()
136+
return stdinFunc(pIn)
137+
}
138+
b.wg.Add(1)
139+
}
140+
141+
handleStdout := noop
142+
if stdoutFunc != nil {
165143
pOut, err := b.cmd.StdoutPipe()
166144
if err != nil {
167145
return err
168146
}
169-
147+
handleStdout = func() error {
148+
defer func() {
149+
pOut.Close()
150+
b.wg.Done()
151+
}()
152+
return stdoutFunc(pOut)
153+
}
154+
b.wg.Add(1)
155+
}
156+
handleStderr := noop
157+
if stderrFunc != nil {
170158
pErr, err := b.cmd.StderrPipe()
171159
if err != nil {
172160
return err
173161
}
174-
175-
run = func() error {
176-
return f(pIn, pOut, pErr)
162+
handleStderr = func() error {
163+
defer func() {
164+
pErr.Close()
165+
b.wg.Done()
166+
}()
167+
return stderrFunc(pErr)
177168
}
169+
b.wg.Add(1)
178170
}
179171

180172
if err := b.cmd.Start(); err != nil {
181173
return err
182174
}
183175

184-
return run()
176+
go handleStdout()
177+
go handleStderr()
178+
go handleInput()
179+
180+
return nil
185181
}
186182

187183
// Start does a Cmd.Start on the command
188184
func (b *Builder) Start(args ...interface{}) error {
189185
if err := b.Prepare(args...); err != nil {
190186
return err
191187
}
192-
return b.StartWithStreams(nil, args...)
188+
return b.StartWithHandlers(nil, nil, nil, args...)
193189
}
194190

195191
// Wait waits for the command to complete
196192
func (b *Builder) Wait() error {
193+
b.wg.Wait()
197194
return b.cmd.Wait()
198195
}
199196

@@ -260,6 +257,15 @@ func (b *Builder) Prepare(args ...interface{}) error {
260257
if b.inheritEnvs {
261258
b.cmd.Env = append(os.Environ(), b.envs...)
262259
}
260+
if b.stdin != nil {
261+
b.cmd.Stdin = b.stdin
262+
}
263+
if b.stdout != nil {
264+
b.cmd.Stdout = b.stdout
265+
}
266+
if b.stderr != nil {
267+
b.cmd.Stderr = b.stderr
268+
}
263269
return nil
264270
}
265271

pkg/util/exec/exec_test.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func TestBuilder(t *testing.T) {
5555

5656
}
5757

58-
func _TestRun(t *testing.T) {
58+
func TestRunDocker(t *testing.T) {
5959

6060
if SkipTests("docker") {
6161
t.SkipNow()
@@ -83,21 +83,24 @@ func _TestRun(t *testing.T) {
8383
dockerStop.Wait()
8484
}()
8585

86-
err = dateStream.InheritEnvs(true).WithArg("container", name).StartWithStreams(MergeOutput(os.Stderr))
86+
err = dateStream.InheritEnvs(true).WithArg("container", name).
87+
WithStdout(os.Stderr).
88+
WithStderr(os.Stderr).
89+
Start()
8790
require.NoError(t, err)
8891

8992
// testing with stdin
90-
err = Command("/bin/sh").InheritEnvs(true).StartWithStreams(
91-
Do(SendInput(
92-
func(stdin io.WriteCloser) error {
93+
err = Command("/bin/sh").InheritEnvs(true).
94+
WithStdout(os.Stderr).
95+
WithStderr(os.Stderr).
96+
StartWithHandlers(
97+
func(stdin io.Writer) error {
9398
T(100).Info("about to write to stdin")
9499
stdin.Write([]byte(`for i in $(seq 10); do echo $i; sleep 1; done`))
95100
T(100).Info("wrote to stdin")
96101
return nil
97-
})).Then(MergeOutput(os.Stderr)).Done(),
98-
)
102+
}, nil, nil)
99103
require.NoError(t, err)
100-
101104
}
102105

103106
func TestPipeline1(t *testing.T) {

0 commit comments

Comments
 (0)