Skip to content

Commit d5d6985

Browse files
committed
IOConnectorPair must close its connectors at the end
Before this change, the connectors were closed by the cancelletion of its associated (passed) context, which seems too early. Signed-off-by: Kazuyoshi Kato <[email protected]>
1 parent eba4e8b commit d5d6985

File tree

4 files changed

+80
-13
lines changed

4 files changed

+80
-13
lines changed

internal/vm/ioproxy.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,23 @@ type IOConnectorPair struct {
6767
}
6868

6969
func (connectorPair *IOConnectorPair) proxy(
70-
proc *vmProc,
70+
ctx context.Context,
7171
logger *logrus.Entry,
7272
timeoutAfterExit time.Duration,
7373
) (ioInitDone <-chan error, ioCopyDone <-chan error) {
7474
initDone := make(chan error, 2)
7575
copyDone := make(chan error)
7676

77+
ioCtx, ioCancel := context.WithCancel(context.Background())
78+
7779
// Start the initialization process. Any synchronous setup made by the connectors will
7880
// be completed after these lines. Async setup will be done once initDone is closed in
7981
// the goroutine below.
80-
readerResultCh := connectorPair.ReadConnector(proc.ctx, logger.WithField("direction", "read"))
81-
writerResultCh := connectorPair.WriteConnector(proc.ctx, logger.WithField("direction", "write"))
82+
readerResultCh := connectorPair.ReadConnector(ioCtx, logger.WithField("direction", "read"))
83+
writerResultCh := connectorPair.WriteConnector(ioCtx, logger.WithField("direction", "write"))
8284

8385
go func() {
86+
defer ioCancel()
8487
defer close(copyDone)
8588

8689
var reader io.ReadCloser
@@ -119,7 +122,7 @@ func (connectorPair *IOConnectorPair) proxy(
119122
// If the io streams close on their own before the timeout, the Close calls here
120123
// should just be no-ops.
121124
go func() {
122-
<-proc.ctx.Done()
125+
<-ctx.Done()
123126
time.AfterFunc(timeoutAfterExit, func() {
124127
logClose(logger, reader, writer)
125128
})
@@ -129,6 +132,7 @@ func (connectorPair *IOConnectorPair) proxy(
129132
defer logger.Debug("end copying io")
130133

131134
size, err := io.CopyBuffer(writer, reader, make([]byte, internal.DefaultBufferSize))
135+
logger.Debugf("copied %d", size)
132136
if err != nil {
133137
if strings.Contains(err.Error(), "use of closed network connection") ||
134138
strings.Contains(err.Error(), "file already closed") {
@@ -138,7 +142,6 @@ func (connectorPair *IOConnectorPair) proxy(
138142
}
139143
copyDone <- err
140144
}
141-
logger.Debugf("copied %d", size)
142145
defer logClose(logger, reader, writer)
143146
}()
144147

@@ -174,19 +177,20 @@ func (ioConnectorSet *ioConnectorSet) start(proc *vmProc) (ioInitDone <-chan err
174177
if ioConnectorSet.stdin != nil {
175178
// For Stdin only, provide 0 as the timeout to wait after the proc exits before closing IO streams.
176179
// There's no reason to send stdin data to a proc that's already dead.
177-
waitErrs(ioConnectorSet.stdin.proxy(proc, proc.logger.WithField("stream", "stdin"), 0))
180+
waitErrs(ioConnectorSet.stdin.proxy(proc.ctx, proc.logger.WithField("stream", "stdin"), 0))
181+
178182
} else {
179183
proc.logger.Debug("skipping proxy io for unset stdin")
180184
}
181185

182186
if ioConnectorSet.stdout != nil {
183-
waitErrs(ioConnectorSet.stdout.proxy(proc, proc.logger.WithField("stream", "stdout"), defaultIOFlushTimeout))
187+
waitErrs(ioConnectorSet.stdout.proxy(proc.ctx, proc.logger.WithField("stream", "stdout"), defaultIOFlushTimeout))
184188
} else {
185189
proc.logger.Debug("skipping proxy io for unset stdout")
186190
}
187191

188192
if ioConnectorSet.stderr != nil {
189-
waitErrs(ioConnectorSet.stderr.proxy(proc, proc.logger.WithField("stream", "stderr"), defaultIOFlushTimeout))
193+
waitErrs(ioConnectorSet.stderr.proxy(proc.ctx, proc.logger.WithField("stream", "stderr"), defaultIOFlushTimeout))
190194
} else {
191195
proc.logger.Debug("skipping proxy io for unset stderr")
192196
}

internal/vm/ioproxy_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"). You may
4+
// not use this file except in compliance with the License. A copy of the
5+
// License is located at
6+
//
7+
// http://aws.amazon.com/apache2.0/
8+
//
9+
// or in the "license" file accompanying this file. This file is distributed
10+
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
// express or implied. See the License for the specific language governing
12+
// permissions and limitations under the License.
13+
14+
package vm
15+
16+
import (
17+
"context"
18+
"io/ioutil"
19+
"os"
20+
"path/filepath"
21+
"testing"
22+
23+
"github.com/sirupsen/logrus"
24+
"github.com/stretchr/testify/assert"
25+
"github.com/stretchr/testify/require"
26+
)
27+
28+
func fileConnector(path string, flag int) IOConnector {
29+
return func(procCtx context.Context, logger *logrus.Entry) <-chan IOConnectorResult {
30+
returnCh := make(chan IOConnectorResult, 1)
31+
defer close(returnCh)
32+
33+
file, err := os.OpenFile(path, flag, 0600)
34+
returnCh <- IOConnectorResult{
35+
ReadWriteCloser: file,
36+
Err: err,
37+
}
38+
39+
return returnCh
40+
}
41+
}
42+
43+
func TestProxy(t *testing.T) {
44+
dir, err := ioutil.TempDir("", t.Name())
45+
require.NoError(t, err)
46+
defer os.RemoveAll(dir)
47+
48+
ctx := context.Background()
49+
content := "hello world"
50+
51+
err = ioutil.WriteFile(filepath.Join(dir, "input"), []byte(content), 0600)
52+
require.NoError(t, err)
53+
54+
pair := &IOConnectorPair{
55+
ReadConnector: fileConnector(filepath.Join(dir, "input"), os.O_RDONLY),
56+
WriteConnector: fileConnector(filepath.Join(dir, "output"), os.O_CREATE|os.O_WRONLY),
57+
}
58+
initCh, copyCh := pair.proxy(ctx, logrus.WithFields(logrus.Fields{}), 0)
59+
60+
assert.Nil(t, <-initCh)
61+
assert.Nil(t, <-copyCh)
62+
63+
bytes, err := ioutil.ReadFile(filepath.Join(dir, "output"))
64+
require.NoError(t, err)
65+
assert.Equal(t, content, string(bytes))
66+
}

internal/vm/task.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,6 @@ func (m *taskManager) monitorExit(proc *vmProc, taskService taskAPI.TaskService)
269269
ID: proc.taskID,
270270
ExecID: proc.execID,
271271
})
272-
273-
<-proc.ioCopyDone
274-
275272
proc.cancel()
276273

277274
if waitErr == context.Canceled {

runtime/service_integ_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ func TestExec_Isolated(t *testing.T) {
15891589
var taskStdout bytes.Buffer
15901590
var taskStderr bytes.Buffer
15911591

1592-
task, err := c.NewTask(ctx, cio.NewCreator(cio.WithStreams(nil, &taskStdout, &taskStderr)))
1592+
task, err := c.NewTask(ctx, cio.NewCreator(cio.WithStreams(os.Stdin, &taskStdout, &taskStderr)))
15931593
require.NoError(t, err, "failed to create task for container %s", c.ID())
15941594

15951595
taskExitCh, err := task.Wait(ctx)
@@ -1604,7 +1604,7 @@ func TestExec_Isolated(t *testing.T) {
16041604
taskExec, err := task.Exec(ctx, "exec", &specs.Process{
16051605
Args: []string{"/bin/date"},
16061606
Cwd: "/",
1607-
}, cio.NewCreator(cio.WithStreams(nil, &execStdout, &execStderr)))
1607+
}, cio.NewCreator(cio.WithStreams(os.Stdin, &execStdout, &execStderr)))
16081608
require.NoError(t, err)
16091609

16101610
execExitCh, err := taskExec.Wait(ctx)

0 commit comments

Comments
 (0)