Skip to content

Commit 1b3f031

Browse files
committed
chore: add version flag test
1 parent 76f83d1 commit 1b3f031

File tree

2 files changed

+164
-4
lines changed

2 files changed

+164
-4
lines changed

pkg/version/flag.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@ package version
1010

1111
import (
1212
"fmt"
13+
"io"
1314
"os"
1415
"strconv"
16+
"strings"
1517

1618
flag "github.com/spf13/pflag"
19+
"k8s.io/component-base/version"
1720
)
1821

1922
type versionValue int
@@ -39,6 +42,16 @@ func (v *versionValue) Set(s string) error {
3942
*v = VersionRaw
4043
return nil
4144
}
45+
46+
if strings.HasPrefix(s, "v") {
47+
err := version.SetDynamicVersion(s)
48+
if err == nil {
49+
value, _ := strconv.Atoi(s)
50+
*v = versionValue(value)
51+
}
52+
return err
53+
}
54+
4255
boolVal, err := strconv.ParseBool(s)
4356
if boolVal {
4457
*v = VersionTrue
@@ -83,14 +96,21 @@ func AddFlags(fs *flag.FlagSet) {
8396
fs.AddFlag(flag.Lookup(versionFlagName))
8497
}
8598

99+
// variables for unit testing PrintAndExitIfRequested
100+
var (
101+
appName = "onex-apiserver"
102+
output = io.Writer(os.Stdout)
103+
exit = os.Exit
104+
)
105+
86106
// PrintAndExitIfRequested will check if the -version flag was passed
87107
// and, if so, print the version and exit.
88108
func PrintAndExitIfRequested(appName string) {
89109
if *versionFlag == VersionRaw {
90-
fmt.Printf("%s\n", Get().Text())
91-
os.Exit(0)
110+
fmt.Fprintf(output, "%s\n", Get().Text())
111+
exit(0)
92112
} else if *versionFlag == VersionTrue {
93-
fmt.Printf("%s %s\n", appName, Get().GitVersion)
94-
os.Exit(0)
113+
fmt.Fprintf(output, "%s %s\n", appName, Get().GitVersion)
114+
exit(0)
95115
}
96116
}

pkg/version/flag_test.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package version
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"strings"
7+
"testing"
8+
9+
"github.com/spf13/pflag"
10+
"k8s.io/component-base/version"
11+
)
12+
13+
func TestVersionFlag(t *testing.T) {
14+
initialFlagValue := int(*versionFlag)
15+
initialVersion := Get()
16+
17+
testcases := []struct {
18+
name string
19+
flags []string
20+
expectError string
21+
expectExit bool
22+
expectPrintVersion string
23+
expectGitVersion string
24+
}{
25+
{
26+
name: "no flag",
27+
flags: []string{},
28+
expectGitVersion: initialVersion.GitVersion,
29+
},
30+
{
31+
name: "false",
32+
flags: []string{"--version=false"},
33+
expectGitVersion: initialVersion.GitVersion,
34+
},
35+
36+
{
37+
name: "valueless",
38+
flags: []string{"--version"},
39+
expectGitVersion: initialVersion.GitVersion,
40+
expectExit: true,
41+
expectPrintVersion: appName + " " + initialVersion.GitVersion,
42+
},
43+
{
44+
name: "true",
45+
flags: []string{"--version=true"},
46+
expectGitVersion: initialVersion.GitVersion,
47+
expectExit: true,
48+
expectPrintVersion: appName + " " + initialVersion.GitVersion,
49+
},
50+
{
51+
name: "raw",
52+
flags: []string{"--version=raw"},
53+
expectGitVersion: initialVersion.GitVersion,
54+
expectExit: true,
55+
expectPrintVersion: fmt.Sprintf("%s", strings.TrimSpace(initialVersion.Text())),
56+
},
57+
{
58+
name: "truthy",
59+
flags: []string{"--version=T"},
60+
expectGitVersion: initialVersion.GitVersion,
61+
expectExit: true,
62+
expectPrintVersion: appName + " " + initialVersion.GitVersion,
63+
},
64+
{
65+
name: "override",
66+
flags: []string{"--version=v0.0.0-custom"},
67+
expectGitVersion: "v0.0.0-custom",
68+
},
69+
{
70+
name: "invalid override semver",
71+
flags: []string{"--version=vX"},
72+
expectError: `could not parse "vX"`,
73+
},
74+
{
75+
name: "invalid override major",
76+
flags: []string{"--version=v1.0.0"},
77+
expectError: `must match major/minor/patch`,
78+
},
79+
{
80+
name: "invalid override minor",
81+
flags: []string{"--version=v0.1.0"},
82+
expectError: `must match major/minor/patch`,
83+
},
84+
{
85+
name: "invalid override patch",
86+
flags: []string{"--version=v0.0.1"},
87+
expectError: `must match major/minor/patch`,
88+
},
89+
}
90+
91+
for _, tc := range testcases {
92+
t.Run(tc.name, func(t *testing.T) {
93+
94+
originalOutput := output
95+
originalExit := exit
96+
97+
outputBuffer := &bytes.Buffer{}
98+
output = outputBuffer
99+
exitCalled := false
100+
exit = func(code int) { exitCalled = true }
101+
102+
t.Cleanup(func() {
103+
output = originalOutput
104+
exit = originalExit
105+
*versionFlag = versionValue(initialFlagValue)
106+
err := version.SetDynamicVersion(initialVersion.GitVersion)
107+
if err != nil {
108+
t.Fatal(err)
109+
}
110+
})
111+
112+
fs := pflag.NewFlagSet("test", pflag.ContinueOnError)
113+
AddFlags(fs)
114+
err := fs.Parse(tc.flags)
115+
if tc.expectError != "" {
116+
if err == nil {
117+
t.Fatal("expected error, got none")
118+
}
119+
if !strings.Contains(err.Error(), tc.expectError) {
120+
t.Fatalf("expected error containing %q, got %q", tc.expectError, err.Error())
121+
}
122+
return
123+
} else if err != nil {
124+
t.Fatalf("unexpected parse error: %v", err)
125+
}
126+
127+
if e, a := tc.expectGitVersion, version.Get().GitVersion; e != a {
128+
t.Fatalf("gitversion: expected %v, got %v", e, a)
129+
}
130+
131+
PrintAndExitIfRequested(appName)
132+
if e, a := tc.expectExit, exitCalled; e != a {
133+
t.Fatalf("exit(): expected %v, got %v", e, a)
134+
}
135+
if e, a := tc.expectPrintVersion, strings.TrimSpace(outputBuffer.String()); e != a {
136+
t.Fatalf("print version: expected %v, got %v", e, a)
137+
}
138+
})
139+
}
140+
}

0 commit comments

Comments
 (0)