Skip to content

Commit 477664a

Browse files
authored
add coverage and tests to utils/utils.go (#31)
1 parent 67e7ffe commit 477664a

File tree

4 files changed

+151
-6
lines changed

4 files changed

+151
-6
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,9 @@ doc/*.xml
1212
# packages
1313
sshproxy_*.tar.gz
1414

15+
# coverage outputs
16+
test/coverage.out
17+
test/coverage.html
18+
1519
# benchmark results
1620
benchmarks/results

Makefile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ check:
9090
$(GO) vet $(TEST)
9191
staticcheck ./...
9292
staticcheck $(TEST)
93-
$(GO) test -failfast -race -count=1 -timeout=10s ./...
93+
$(GO) test -coverprofile=test/coverage.out -failfast -race -count=1 -timeout=10s ./...
94+
$(GO) tool cover -html=test/coverage.out -o test/coverage.html
9495

9596
test:
9697
cd test && bash ./run.sh
@@ -100,6 +101,6 @@ benchmark:
100101
$(GO) test -failfast -race -count=6 -bench=. -run=^# -benchmem ./... | tee benchmarks/results/$(DATE)-$(COMMIT)
101102

102103
clean:
103-
rm -f $(EXE) $(MANDOC) doc/*.xml sshproxy_*.tar.gz
104+
rm -f $(EXE) $(MANDOC) doc/*.xml sshproxy_*.tar.gz test/coverage.*
104105

105106
.PHONY: all exe doc install install-doc-man install-binaries package fmt get-deps check test benchmark clean

pkg/utils/utils.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@ func SplitHostPort(hostport string) (string, string, error) {
5050
return host, strconv.Itoa(portNum), nil
5151
}
5252

53-
// GetGroupUser returns a map of group memberships for the specifised user.
53+
// Mocking user.Lookup and user.LookupGroupId for testing.
54+
var userCurrent = user.Current
55+
var userLookup = user.Lookup
56+
var userLookupGroupId = user.LookupGroupId
57+
58+
// GetGroupUser returns a map of group memberships for the specified user.
5459
//
5560
// It can be used to quickly check if a user is in a specified group.
5661
func GetGroupUser(u *user.User) (map[string]bool, error) {
@@ -61,7 +66,7 @@ func GetGroupUser(u *user.User) (map[string]bool, error) {
6166

6267
groups := make(map[string]bool)
6368
for _, gid := range groupids {
64-
g, err := user.LookupGroupId(gid)
69+
g, err := userLookupGroupId(gid)
6570
if err != nil {
6671
return nil, err
6772
}
@@ -76,7 +81,7 @@ func GetGroupUser(u *user.User) (map[string]bool, error) {
7681
//
7782
// It can be used to quickly check if a user is in a specified group.
7883
func GetGroups() (map[string]bool, error) {
79-
u, err := user.Current()
84+
u, err := userCurrent()
8085
if err != nil {
8186
return nil, err
8287
}
@@ -93,7 +98,7 @@ func GetGroups() (map[string]bool, error) {
9398
//
9499
// It can be used to quickly check if a user is in a specified group.
95100
func GetGroupList(username string) (map[string]bool, error) {
96-
u, err := user.Lookup(username)
101+
u, err := userLookup(username)
97102
if err != nil {
98103
return nil, err
99104
}

pkg/utils/utils_test.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@ package utils
1212

1313
import (
1414
"errors"
15+
"fmt"
16+
"os/user"
1517
"reflect"
18+
"sort"
19+
"strings"
1620
"testing"
1721
"time"
1822
)
@@ -97,6 +101,116 @@ func BenchmarkSplitHostPort(b *testing.B) {
97101
}
98102
}
99103

104+
func mockUserCurrent() (*user.User, error) {
105+
return mockUserLookup("testuser")
106+
}
107+
108+
func mockUserLookup(username string) (*user.User, error) {
109+
var u user.User
110+
if username == "root" {
111+
u.Uid = "0"
112+
u.Gid = "0"
113+
} else if username == "testuser" {
114+
u.Uid = "1000"
115+
u.Gid = "1000"
116+
} else if username == "userwithnogroupid" {
117+
u.Uid = "1001"
118+
} else if username == "userwithinvalidgroup" {
119+
u.Uid = "1002"
120+
u.Gid = "1002"
121+
} else {
122+
return &u, fmt.Errorf("user: unknown user %s", username)
123+
}
124+
u.Username = username
125+
return &u, nil
126+
}
127+
128+
func mockUserLookupGroupId(gid string) (*user.Group, error) {
129+
var g user.Group
130+
g.Gid = gid
131+
if gid == "0" {
132+
g.Name = "root"
133+
} else if gid == "1000" {
134+
g.Name = "testgroup"
135+
} else {
136+
return &g, fmt.Errorf("group: unknown group ID %s", gid)
137+
}
138+
return &g, nil
139+
}
140+
141+
func BenchmarkGetGroupUser(b *testing.B) {
142+
current, _ := userCurrent()
143+
root, _ := userLookup("root")
144+
for _, tt := range []*user.User{current, root} {
145+
b.Run(tt.Username, func(b *testing.B) {
146+
for i := 0; i < b.N; i++ {
147+
GetGroupUser(tt)
148+
}
149+
})
150+
}
151+
}
152+
153+
func TestGetGroups(t *testing.T) {
154+
userCurrent = mockUserCurrent
155+
userLookupGroupId = mockUserLookupGroupId
156+
groups, err := GetGroups()
157+
if err != nil {
158+
t.Errorf("GetGroups error = '%v', want nil", err)
159+
} else if len(groups) < 1 {
160+
t.Error("GetGroups must return at least one group")
161+
}
162+
}
163+
164+
func BenchmarkGetGroups(b *testing.B) {
165+
for i := 0; i < b.N; i++ {
166+
GetGroups()
167+
}
168+
}
169+
170+
var getGroupListTests = []struct {
171+
user, groups, err string
172+
}{
173+
{"root", "root", ""},
174+
{"testuser", "testgroup", ""},
175+
{"userwithnogroupid", "", "user: list groups for userwithnogroupid: invalid gid \"\""},
176+
{"userwithinvalidgroup", "nonexistentgroup", "group: unknown group ID 1002"},
177+
{"nonexistentuser", "nonexistentgroup", "user: unknown user nonexistentuser"},
178+
}
179+
180+
func TestGetGroupList(t *testing.T) {
181+
userLookup = mockUserLookup
182+
userLookupGroupId = mockUserLookupGroupId
183+
for _, tt := range getGroupListTests {
184+
groups, err := GetGroupList(tt.user)
185+
if err != nil {
186+
if fmt.Sprintf("%s", err) != tt.err {
187+
t.Errorf("GetGroupList error = '%v', want '%v'", err, tt.err)
188+
}
189+
} else {
190+
g := make([]string, 0, len(groups))
191+
for group := range groups {
192+
g = append(g, group)
193+
}
194+
sort.Strings(g)
195+
if strings.Join(g, " ") != tt.groups {
196+
t.Errorf("GetGroupList groups = %v, want %v", strings.Join(g, " "), tt.groups)
197+
}
198+
}
199+
}
200+
}
201+
202+
func BenchmarkGetGroupList(b *testing.B) {
203+
for _, tt := range getGroupListTests {
204+
if tt.err == "" {
205+
b.Run(tt.user, func(b *testing.B) {
206+
for i := 0; i < b.N; i++ {
207+
GetGroupList(tt.user)
208+
}
209+
})
210+
}
211+
}
212+
}
213+
100214
func mockNetLookupHost(host string) ([]string, error) {
101215
if host == "err" {
102216
return nil, errors.New("LookupHost error")
@@ -106,6 +220,17 @@ func mockNetLookupHost(host string) ([]string, error) {
106220
return []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"}, nil
107221
}
108222

223+
func BenchmarkNormalizeHostPort(b *testing.B) {
224+
netLookupHost = mockNetLookupHost
225+
for _, tt := range []string{"127.0.0.1", "127.0.0.1:123", "server1", "server1:123", "host:port:invalid", "err"} {
226+
b.Run(tt, func(b *testing.B) {
227+
for i := 0; i < b.N; i++ {
228+
normalizeHostPort(tt)
229+
}
230+
})
231+
}
232+
}
233+
109234
var matchSourceTests = []struct {
110235
source, sshdHostport string
111236
want bool
@@ -160,6 +285,16 @@ var matchSourceTests = []struct {
160285
"server2",
161286
true,
162287
},
288+
{
289+
"127.0.0.1:22",
290+
"127.0.0.1:122",
291+
false,
292+
},
293+
{
294+
"127.0.0.1:22",
295+
"127.0.0.2:22",
296+
false,
297+
},
163298
}
164299

165300
func TestMatchSource(t *testing.T) {

0 commit comments

Comments
 (0)