@@ -12,7 +12,11 @@ package utils
1212
1313import (
1414 "errors"
15+ "fmt"
16+ "os/user"
1517 "reflect"
18+ "sort"
19+ "strings"
1620 "testing"
1721 "time"
1822)
@@ -97,6 +101,116 @@ func BenchmarkSplitHostPort(b *testing.B) {
97101 }
98102}
99103
104+ func mockUserCurrent () (* user.User , error ) {
105+ return mockUserLookup ("testuser" )
106+ }
107+
108+ func mockUserLookup (username string ) (* user.User , error ) {
109+ var u user.User
110+ if username == "root" {
111+ u .Uid = "0"
112+ u .Gid = "0"
113+ } else if username == "testuser" {
114+ u .Uid = "1000"
115+ u .Gid = "1000"
116+ } else if username == "userwithnogroupid" {
117+ u .Uid = "1001"
118+ } else if username == "userwithinvalidgroup" {
119+ u .Uid = "1002"
120+ u .Gid = "1002"
121+ } else {
122+ return & u , fmt .Errorf ("user: unknown user %s" , username )
123+ }
124+ u .Username = username
125+ return & u , nil
126+ }
127+
128+ func mockUserLookupGroupId (gid string ) (* user.Group , error ) {
129+ var g user.Group
130+ g .Gid = gid
131+ if gid == "0" {
132+ g .Name = "root"
133+ } else if gid == "1000" {
134+ g .Name = "testgroup"
135+ } else {
136+ return & g , fmt .Errorf ("group: unknown group ID %s" , gid )
137+ }
138+ return & g , nil
139+ }
140+
141+ func BenchmarkGetGroupUser (b * testing.B ) {
142+ current , _ := userCurrent ()
143+ root , _ := userLookup ("root" )
144+ for _ , tt := range []* user.User {current , root } {
145+ b .Run (tt .Username , func (b * testing.B ) {
146+ for i := 0 ; i < b .N ; i ++ {
147+ GetGroupUser (tt )
148+ }
149+ })
150+ }
151+ }
152+
153+ func TestGetGroups (t * testing.T ) {
154+ userCurrent = mockUserCurrent
155+ userLookupGroupId = mockUserLookupGroupId
156+ groups , err := GetGroups ()
157+ if err != nil {
158+ t .Errorf ("GetGroups error = '%v', want nil" , err )
159+ } else if len (groups ) < 1 {
160+ t .Error ("GetGroups must return at least one group" )
161+ }
162+ }
163+
164+ func BenchmarkGetGroups (b * testing.B ) {
165+ for i := 0 ; i < b .N ; i ++ {
166+ GetGroups ()
167+ }
168+ }
169+
170+ var getGroupListTests = []struct {
171+ user , groups , err string
172+ }{
173+ {"root" , "root" , "" },
174+ {"testuser" , "testgroup" , "" },
175+ {"userwithnogroupid" , "" , "user: list groups for userwithnogroupid: invalid gid \" \" " },
176+ {"userwithinvalidgroup" , "nonexistentgroup" , "group: unknown group ID 1002" },
177+ {"nonexistentuser" , "nonexistentgroup" , "user: unknown user nonexistentuser" },
178+ }
179+
180+ func TestGetGroupList (t * testing.T ) {
181+ userLookup = mockUserLookup
182+ userLookupGroupId = mockUserLookupGroupId
183+ for _ , tt := range getGroupListTests {
184+ groups , err := GetGroupList (tt .user )
185+ if err != nil {
186+ if fmt .Sprintf ("%s" , err ) != tt .err {
187+ t .Errorf ("GetGroupList error = '%v', want '%v'" , err , tt .err )
188+ }
189+ } else {
190+ g := make ([]string , 0 , len (groups ))
191+ for group := range groups {
192+ g = append (g , group )
193+ }
194+ sort .Strings (g )
195+ if strings .Join (g , " " ) != tt .groups {
196+ t .Errorf ("GetGroupList groups = %v, want %v" , strings .Join (g , " " ), tt .groups )
197+ }
198+ }
199+ }
200+ }
201+
202+ func BenchmarkGetGroupList (b * testing.B ) {
203+ for _ , tt := range getGroupListTests {
204+ if tt .err == "" {
205+ b .Run (tt .user , func (b * testing.B ) {
206+ for i := 0 ; i < b .N ; i ++ {
207+ GetGroupList (tt .user )
208+ }
209+ })
210+ }
211+ }
212+ }
213+
100214func mockNetLookupHost (host string ) ([]string , error ) {
101215 if host == "err" {
102216 return nil , errors .New ("LookupHost error" )
@@ -106,6 +220,17 @@ func mockNetLookupHost(host string) ([]string, error) {
106220 return []string {"1.1.1.1" , "2.2.2.2" , "3.3.3.3" }, nil
107221}
108222
223+ func BenchmarkNormalizeHostPort (b * testing.B ) {
224+ netLookupHost = mockNetLookupHost
225+ for _ , tt := range []string {"127.0.0.1" , "127.0.0.1:123" , "server1" , "server1:123" , "host:port:invalid" , "err" } {
226+ b .Run (tt , func (b * testing.B ) {
227+ for i := 0 ; i < b .N ; i ++ {
228+ normalizeHostPort (tt )
229+ }
230+ })
231+ }
232+ }
233+
109234var matchSourceTests = []struct {
110235 source , sshdHostport string
111236 want bool
@@ -160,6 +285,16 @@ var matchSourceTests = []struct {
160285 "server2" ,
161286 true ,
162287 },
288+ {
289+ "127.0.0.1:22" ,
290+ "127.0.0.1:122" ,
291+ false ,
292+ },
293+ {
294+ "127.0.0.1:22" ,
295+ "127.0.0.2:22" ,
296+ false ,
297+ },
163298}
164299
165300func TestMatchSource (t * testing.T ) {
0 commit comments