Skip to content

Commit 524bb63

Browse files
authored
ztest: add context.Context param to RunInternal and RunScript (#6336)
1 parent d1f4eaa commit 524bb63

File tree

4 files changed

+20
-19
lines changed

4 files changed

+20
-19
lines changed

runtime/sam/expr/expr_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func testSuccessful(t *testing.T, e, input, expected string) {
1919
Input: &input,
2020
Output: expected + "\n",
2121
}
22-
if err := zt.RunInternal(); err != nil {
22+
if err := zt.RunInternal(t.Context()); err != nil {
2323
t.Fatal(err)
2424
}
2525
}
@@ -30,7 +30,7 @@ func testError(t *testing.T, e string, expectErr error) {
3030
SPQ: fmt.Sprintf("values %s", e),
3131
Error: expectErr.Error() + "\n",
3232
}
33-
if err := zt.RunInternal(); err != nil {
33+
if err := zt.RunInternal(t.Context()); err != nil {
3434
t.Fatal(err)
3535
}
3636
}

ztest/shell.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@ package ztest
22

33
import (
44
"bytes"
5+
"context"
56
"fmt"
67
"io"
78
"os"
89
"os/exec"
910
)
1011

11-
func RunShell(dir, bindir, script string, stdin io.Reader, useenvs, extraenvs []string) (string, string, error) {
12+
func RunShell(ctx context.Context, dir, bindir, script string, stdin io.Reader, useenvs, extraenvs []string) (string, string, error) {
1213
// "-e -o pipefile" ensures a test will fail if any command
1314
// fails unexpectedly.
14-
cmd := exec.Command("bash", "-e", "-o", "pipefail", "-c", script)
15+
cmd := exec.CommandContext(ctx, "bash", "-e", "-o", "pipefail", "-c", script)
1516
cmd.Dir = dir
1617
cmd.Env = []string{
1718
"AppData=" + dir, // For os.UserConfigDir on Windows.

ztest/ztest.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -330,42 +330,42 @@ func (z *ZTest) ShouldSkip(path string) string {
330330
return ""
331331
}
332332

333-
func (z *ZTest) RunScript(shellPath, testDir string, tempDir func() string) error {
333+
func (z *ZTest) RunScript(ctx context.Context, shellPath, testDir string, tempDir func() string) error {
334334
if err := z.check(); err != nil {
335335
return fmt.Errorf("bad yaml format: %w", err)
336336
}
337-
serr := runsh(shellPath, testDir, tempDir(), z)
337+
serr := runsh(ctx, shellPath, testDir, tempDir(), z)
338338
if !z.Vector {
339339
return serr
340340
}
341341
if serr != nil {
342342
serr = fmt.Errorf("=== sequence ===\n%w", serr)
343343
}
344-
verr := runsh(shellPath, testDir, tempDir(), z, "SUPER_VAM=1")
344+
verr := runsh(ctx, shellPath, testDir, tempDir(), z, "SUPER_VAM=1")
345345
if verr != nil {
346346
verr = fmt.Errorf("=== vector ===\n%w", verr)
347347
}
348348
return errors.Join(serr, verr)
349349
}
350350

351-
func (z *ZTest) RunInternal() error {
351+
func (z *ZTest) RunInternal(ctx context.Context) error {
352352
if err := z.check(); err != nil {
353353
return fmt.Errorf("bad yaml format: %w", err)
354354
}
355355
outputFlags := append([]string{"-f=sup", "-pretty=0"}, strings.Fields(z.OutputFlags)...)
356356
inputFlags := strings.Fields(z.InputFlags)
357357
if z.Vector {
358-
verr := z.diffInternal(runInternal(z.SPQ, z.Input, outputFlags, inputFlags, true))
358+
verr := z.diffInternal(runInternal(ctx, z.SPQ, z.Input, outputFlags, inputFlags, true))
359359
if verr != nil {
360360
verr = fmt.Errorf("=== vector ===\n%w", verr)
361361
}
362-
serr := z.diffInternal(runInternal(z.SPQ, z.Input, outputFlags, inputFlags, false))
362+
serr := z.diffInternal(runInternal(ctx, z.SPQ, z.Input, outputFlags, inputFlags, false))
363363
if serr != nil {
364364
serr = fmt.Errorf("=== sequence ===\n%w", serr)
365365
}
366366
return errors.Join(verr, serr)
367367
}
368-
return z.diffInternal(runInternal(z.SPQ, z.Input, outputFlags, inputFlags, false))
368+
return z.diffInternal(runInternal(ctx, z.SPQ, z.Input, outputFlags, inputFlags, false))
369369
}
370370

371371
func (z *ZTest) diffInternal(out string, err error) error {
@@ -390,9 +390,9 @@ func (z *ZTest) Run(t *testing.T, path, filename string) {
390390
}
391391
var err error
392392
if z.Script != "" {
393-
err = z.RunScript(path, filepath.Dir(filename), t.TempDir)
393+
err = z.RunScript(t.Context(), path, filepath.Dir(filename), t.TempDir)
394394
} else {
395-
err = z.RunInternal()
395+
err = z.RunInternal(t.Context())
396396
}
397397
if err != nil {
398398
t.Fatalf("%s: %s", filename, err)
@@ -417,7 +417,7 @@ func diffErr(name, expected, actual string) error {
417417
return fmt.Errorf("expected and actual %s differ:\n%s", name, diff)
418418
}
419419

420-
func runsh(path, testDir, tempDir string, zt *ZTest, extraEnv ...string) error {
420+
func runsh(ctx context.Context, path, testDir, tempDir string, zt *ZTest, extraEnv ...string) error {
421421
var stdin io.Reader
422422
for _, f := range zt.Inputs {
423423
b, _, err := f.load(testDir)
@@ -432,7 +432,7 @@ func runsh(path, testDir, tempDir string, zt *ZTest, extraEnv ...string) error {
432432
return err
433433
}
434434
}
435-
stdout, stderr, err := RunShell(tempDir, path, zt.Script, stdin, zt.Env, extraEnv)
435+
stdout, stderr, err := RunShell(ctx, tempDir, path, zt.Script, stdin, zt.Env, extraEnv)
436436
if err != nil {
437437
return fmt.Errorf("script failed: %w\n=== stdout ===\n%s=== stderr ===\n%s",
438438
err, stdout, stderr)
@@ -468,7 +468,7 @@ func runsh(path, testDir, tempDir string, zt *ZTest, extraEnv ...string) error {
468468
// runInternal runs query over input and returns the output. input
469469
// may be in any format recognized by "super -i auto" and may be gzip-compressed.
470470
// outputFlags may contain any flags accepted by cli/outputflags.Flags.
471-
func runInternal(query string, input *string, outputFlags, inputFlags []string, vector bool) (string, error) {
471+
func runInternal(ctx context.Context, query string, input *string, outputFlags, inputFlags []string, vector bool) (string, error) {
472472
ast, err := parser.ParseQuery(query)
473473
if err != nil {
474474
return "", err
@@ -496,7 +496,7 @@ func runInternal(query string, input *string, outputFlags, inputFlags []string,
496496
if vector {
497497
env.SetUseVAM()
498498
}
499-
q, err := runtime.CompileQuery(context.Background(), sctx, compiler.NewCompilerWithEnv(env), ast, readers)
499+
q, err := runtime.CompileQuery(ctx, sctx, compiler.NewCompilerWithEnv(env), ast, readers)
500500
if err != nil {
501501
return "", err
502502
}

ztest/ztest_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@ func TestRunScript(t *testing.T) {
4040
{Name: "testdirfile"},
4141
{Name: "testdirfile2", Source: "testdirfile"},
4242
},
43-
}).RunScript("", testDir, t.TempDir)
43+
}).RunScript(t.Context(), "", testDir, t.TempDir)
4444
assert.NoError(t, err)
4545
})
4646
t.Run("error", func(t *testing.T) {
4747
err := (&ZTest{
4848
Script: "echo 1; echo 2 >&2; exit 3",
4949
Outputs: []File{},
50-
}).RunScript("", "", func() string { return "" })
50+
}).RunScript(t.Context(), "", "", func() string { return "" })
5151
assert.EqualError(t, err, "script failed: exit status 3\n=== stdout ===\n1\n=== stderr ===\n2\n")
5252
})
5353
}

0 commit comments

Comments
 (0)