Skip to content

Commit 3b5b4ea

Browse files
authored
Merge pull request #13 from ActiveState/DX-2901
Reimplement process exit expect failure
2 parents 45f7444 + e176b0b commit 3b5b4ea

File tree

9 files changed

+65
-72
lines changed

9 files changed

+65
-72
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ name: unit-tests
22

33
on:
44
push:
5-
branches: [ master ]
5+
branches: [ master, v2-wip ]
66
pull_request:
7-
branches: [ master ]
7+
branches: [ master, v2-wip]
88

99
jobs:
1010

expect.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func (tt *TermTest) ExpectCustom(consumer consumer, opts ...SetExpectOpt) (rerr
8686
return fmt.Errorf("could not create expect options: %w", err)
8787
}
8888

89-
cons, err := tt.outputProducer.addConsumer(tt, consumer, expectOpts.ToConsumerOpts()...)
89+
cons, err := tt.outputProducer.addConsumer(consumer, expectOpts.ToConsumerOpts()...)
9090
if err != nil {
9191
return fmt.Errorf("could not add consumer: %w", err)
9292
}
@@ -180,11 +180,11 @@ func (tt *TermTest) expectExitCode(exitCode int, match bool, opts ...SetExpectOp
180180
select {
181181
case <-time.After(timeoutV):
182182
return fmt.Errorf("after %s: %w", timeoutV, TimeoutError)
183-
case state := <-tt.Exited(false): // do not wait for unread output since it's not read by this select{}
184-
if state.Err != nil && (state.ProcessState == nil || state.ProcessState.ExitCode() == 0) {
185-
return fmt.Errorf("cmd wait failed: %w", state.Err)
183+
case err := <-waitChan(tt.cmd.Wait):
184+
if err != nil && (tt.cmd.ProcessState == nil || tt.cmd.ProcessState.ExitCode() == 0) {
185+
return fmt.Errorf("cmd wait failed: %w", err)
186186
}
187-
if err := tt.assertExitCode(state.ProcessState.ExitCode(), exitCode, match); err != nil {
187+
if err := tt.assertExitCode(tt.cmd.ProcessState.ExitCode(), exitCode, match); err != nil {
188188
return err
189189
}
190190
}

expect_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func Test_ExpectCustom(t *testing.T) {
8585
[]SetExpectOpt{OptExpectTimeout(time.Second)},
8686
},
8787
"",
88-
TimeoutError,
88+
PtyEOF,
8989
},
9090
{
9191
"Custom error",
@@ -167,7 +167,7 @@ func Test_ExpectCustom_Cmd(t *testing.T) {
167167
},
168168
[]SetExpectOpt{OptExpectTimeout(time.Second)},
169169
},
170-
TimeoutError,
170+
PtyEOF,
171171
},
172172
{
173173
"Custom error",
@@ -194,7 +194,7 @@ func Test_ExpectCustom_Cmd(t *testing.T) {
194194
}
195195

196196
func Test_Expect_Timeout(t *testing.T) {
197-
tt := newTermTest(t, exec.Command("bash", "-c", "echo HELLO"), false)
197+
tt := newTermTest(t, exec.Command("bash", "-c", "echo HELLO && sleep 1"), false)
198198
durations := []time.Duration{
199199
100 * time.Millisecond,
200200
200 * time.Millisecond,

helpers.go

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"bytes"
55
"errors"
66
"os"
7-
"os/exec"
87
"strings"
98
"time"
109
)
@@ -23,20 +22,10 @@ type cmdExit struct {
2322
Err error
2423
}
2524

26-
// waitForCmdExit turns process.wait() into a channel so that it can be used within a select{} statement
27-
func waitForCmdExit(cmd *exec.Cmd) chan *cmdExit {
28-
exit := make(chan *cmdExit, 1)
29-
go func() {
30-
err := cmd.Wait()
31-
exit <- &cmdExit{ProcessState: cmd.ProcessState, Err: err}
32-
}()
33-
return exit
34-
}
35-
3625
func waitChan[T any](wait func() T) chan T {
37-
done := make(chan T)
26+
done := make(chan T, 1)
3827
go func() {
39-
done <- wait()
28+
wait()
4029
close(done)
4130
}()
4231
return done

outputconsumer.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ type outputConsumer struct {
1515
opts *OutputConsumerOpts
1616
isalive bool
1717
mutex *sync.Mutex
18-
tt *TermTest
1918
}
2019

2120
type OutputConsumerOpts struct {
@@ -37,7 +36,7 @@ func OptsConsTimeout(timeout time.Duration) func(o *OutputConsumerOpts) {
3736
}
3837
}
3938

40-
func newOutputConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) *outputConsumer {
39+
func newOutputConsumer(consume consumer, opts ...SetConsOpt) *outputConsumer {
4140
oc := &outputConsumer{
4241
consume: consume,
4342
opts: &OutputConsumerOpts{
@@ -47,7 +46,6 @@ func newOutputConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) *outp
4746
waiter: make(chan error, 1),
4847
isalive: true,
4948
mutex: &sync.Mutex{},
50-
tt: tt,
5149
}
5250

5351
for _, optSetter := range opts {
@@ -83,6 +81,23 @@ func (e *outputConsumer) Report(buffer []byte) (int, error) {
8381
return pos, err
8482
}
8583

84+
type errConsumerStopped struct {
85+
reason error
86+
}
87+
88+
func (e errConsumerStopped) Error() string {
89+
return fmt.Sprintf("consumer stopped, reason: %s", e.reason)
90+
}
91+
92+
func (e errConsumerStopped) Unwrap() error {
93+
return e.reason
94+
}
95+
96+
func (e *outputConsumer) Stop(reason error) {
97+
e.opts.Logger.Printf("stopping consumer, reason: %s\n", reason)
98+
e.waiter <- errConsumerStopped{reason}
99+
}
100+
86101
func (e *outputConsumer) wait() error {
87102
e.opts.Logger.Println("started waiting")
88103
defer e.opts.Logger.Println("stopped waiting")
@@ -103,11 +118,5 @@ func (e *outputConsumer) wait() error {
103118
e.mutex.Lock()
104119
e.opts.Logger.Println("Encountered timeout")
105120
return fmt.Errorf("after %s: %w", e.opts.Timeout, TimeoutError)
106-
case state := <-e.tt.Exited(true): // allow for output to be read first by first case in this select{}
107-
e.mutex.Lock()
108-
if state.Err != nil {
109-
e.opts.Logger.Println("Encountered error waiting for process to exit: %s\n", state.Err.Error())
110-
}
111-
return fmt.Errorf("process exited (status: %d)", state.ProcessState.ExitCode())
112121
}
113122
}

outputproducer.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ func (o *outputProducer) listen(r io.Reader, w io.Writer, appendBuffer func([]by
5454
for {
5555
o.opts.Logger.Println("listen: loop")
5656
if err := o.processNextRead(br, w, appendBuffer, size); err != nil {
57-
if errors.Is(err, ptyEOF) {
58-
o.opts.Logger.Println("listen: reached EOF")
57+
if errors.Is(err, PtyEOF) {
5958
return nil
6059
} else {
6160
return fmt.Errorf("could not poll reader: %w", err)
@@ -64,7 +63,7 @@ func (o *outputProducer) listen(r io.Reader, w io.Writer, appendBuffer func([]by
6463
}
6564
}
6665

67-
var ptyEOF = errors.New("pty closed")
66+
var PtyEOF = errors.New("pty closed")
6867

6968
func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer func([]byte, bool) error, size int) error {
7069
o.opts.Logger.Printf("processNextRead started with size: %d\n", size)
@@ -78,6 +77,7 @@ func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer
7877
pathError := &fs.PathError{}
7978
if errors.Is(errRead, fs.ErrClosed) || errors.Is(errRead, io.EOF) || (runtime.GOOS == "linux" && errors.As(errRead, &pathError)) {
8079
isEOF = true
80+
o.opts.Logger.Println("reached EOF")
8181
}
8282
}
8383

@@ -96,7 +96,8 @@ func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer
9696

9797
if errRead != nil {
9898
if isEOF {
99-
return errors.Join(errRead, ptyEOF)
99+
o.closeConsumers(PtyEOF)
100+
return errors.Join(errRead, PtyEOF)
100101
}
101102
return fmt.Errorf("could not read pty output: %w", errRead)
102103
}
@@ -194,6 +195,19 @@ func (o *outputProducer) processDirtyOutput(output []byte, cursorPos int, cleanU
194195
return append(append(alreadyCleanedOutput, processedOutput...), unprocessedOutput...), processedCursorPos, newCleanUptoPos, nil
195196
}
196197

198+
func (o *outputProducer) closeConsumers(reason error) {
199+
o.opts.Logger.Println("closing consumers")
200+
defer o.opts.Logger.Println("closed consumers")
201+
202+
o.mutex.Lock()
203+
defer o.mutex.Unlock()
204+
205+
for n := 0; n < len(o.consumers); n++ {
206+
o.consumers[n].Stop(reason)
207+
o.consumers = append(o.consumers[:n], o.consumers[n+1:]...)
208+
}
209+
}
210+
197211
func (o *outputProducer) flushConsumers() error {
198212
o.opts.Logger.Println("flushing consumers")
199213
defer o.opts.Logger.Println("flushed consumers")
@@ -238,12 +252,12 @@ func (o *outputProducer) flushConsumers() error {
238252
return nil
239253
}
240254

241-
func (o *outputProducer) addConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) (*outputConsumer, error) {
255+
func (o *outputProducer) addConsumer(consume consumer, opts ...SetConsOpt) (*outputConsumer, error) {
242256
o.opts.Logger.Printf("adding consumer")
243257
defer o.opts.Logger.Printf("added consumer")
244258

245259
opts = append(opts, OptConsInherit(o.opts))
246-
listener := newOutputConsumer(tt, consume, opts...)
260+
listener := newOutputConsumer(consume, opts...)
247261
o.consumers = append(o.consumers, listener)
248262

249263
if err := o.flushConsumers(); err != nil {

termtest.go

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ type TermTest struct {
2323
ptmx pty.Pty
2424
outputProducer *outputProducer
2525
listenError chan error
26+
waitError chan error
2627
opts *Opts
27-
exited *cmdExit
2828
}
2929

3030
type ErrorHandler func(*TermTest, error) error
@@ -79,6 +79,7 @@ func New(cmd *exec.Cmd, opts ...SetOpt) (*TermTest, error) {
7979
cmd: cmd,
8080
outputProducer: newOutputProducer(optv),
8181
listenError: make(chan error, 1),
82+
waitError: make(chan error, 1),
8283
opts: optv,
8384
}
8485

@@ -228,6 +229,7 @@ func (tt *TermTest) start() (rerr error) {
228229
tt.term = vt10x.New(vt10x.WithWriter(ptmx), vt10x.WithSize(tt.opts.Cols, tt.opts.Rows))
229230

230231
// Start listening for output
232+
// We use a waitgroup here to ensure the listener is active before consumers are attached.
231233
wg := &sync.WaitGroup{}
232234
wg.Add(1)
233235
go func() {
@@ -236,12 +238,18 @@ func (tt *TermTest) start() (rerr error) {
236238
err := tt.outputProducer.Listen(tt.ptmx, tt.term)
237239
tt.listenError <- err
238240
}()
239-
wg.Wait()
240241

241242
go func() {
242-
tt.exited = <-waitForCmdExit(tt.cmd)
243+
// We start waiting right away, because on Windows the PTY isn't closed until the process exits, which in turn
244+
// can't happen unless we've told the pty we're ready for it to close.
245+
// This of course isn't ideal, but until the pty library fixes the cross-platform inconsistencies we have to
246+
// work around these limitations.
247+
defer tt.opts.Logger.Printf("waitIndefinitely finished")
248+
tt.waitError <- tt.waitIndefinitely()
243249
}()
244250

251+
wg.Wait()
252+
245253
return nil
246254
}
247255

@@ -252,13 +260,8 @@ func (tt *TermTest) Wait(timeout time.Duration) (rerr error) {
252260
tt.opts.Logger.Println("wait called")
253261
defer tt.opts.Logger.Println("wait closed")
254262

255-
errc := make(chan error, 1)
256-
go func() {
257-
errc <- tt.WaitIndefinitely()
258-
}()
259-
260263
select {
261-
case err := <-errc:
264+
case err := <-tt.waitError:
262265
// WaitIndefinitely already invokes the expect error handler
263266
return err
264267
case <-time.After(timeout):
@@ -324,28 +327,6 @@ func (tt *TermTest) SendCtrlC() {
324327
tt.Send(string([]byte{0x03})) // 0x03 is ASCII character for ^C
325328
}
326329

327-
// Exited returns a channel that sends the given termtest's command cmdExit info when available.
328-
// This can be used within a select{} statement.
329-
// If waitExtra is given, waits a little bit before sending cmdExit info. This allows any fellow
330-
// switch cases with output consumers to handle unprocessed stdout. If there are no such cases
331-
// (e.g. ExpectExit(), where we want to catch an exit ASAP), waitExtra should be false.
332-
func (tt *TermTest) Exited(waitExtra bool) chan *cmdExit {
333-
return waitChan(func() *cmdExit {
334-
ticker := time.NewTicker(processExitPollInterval)
335-
for {
336-
select {
337-
case <-ticker.C:
338-
if tt.exited != nil {
339-
if waitExtra { // allow sibling output consumer cases to handle their output
340-
time.Sleep(processExitExtraWait)
341-
}
342-
return tt.exited
343-
}
344-
}
345-
}
346-
})
347-
}
348-
349330
func (tt *TermTest) errorHandler(rerr *error) {
350331
err := *rerr
351332
if err == nil {

termtest_other.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func syscallErrorCode(err error) int {
1212
return -1
1313
}
1414

15-
func (tt *TermTest) WaitIndefinitely() error {
15+
func (tt *TermTest) waitIndefinitely() error {
1616
tt.opts.Logger.Println("WaitIndefinitely called")
1717
defer tt.opts.Logger.Println("WaitIndefinitely closed")
1818

termtest_windows.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ func syscallErrorCode(err error) int {
1616
return 0
1717
}
1818

19-
// WaitIndefinitely on Windows has to work around a Windows PTY bug where the PTY will NEVER exit by itself:
19+
// waitIndefinitely on Windows has to work around a Windows PTY bug where the PTY will NEVER exit by itself:
2020
// https://github.com/photostorm/pty/issues/3
2121
// Instead we wait for the process itself to exit, and after a grace period will shut down the pty.
22-
func (tt *TermTest) WaitIndefinitely() error {
22+
func (tt *TermTest) waitIndefinitely() error {
2323
tt.opts.Logger.Println("WaitIndefinitely called")
2424
defer tt.opts.Logger.Println("WaitIndefinitely closed")
2525

0 commit comments

Comments
 (0)