Skip to content

Commit 5c1def8

Browse files
add unit tests
1 parent 8563980 commit 5c1def8

File tree

2 files changed

+64
-5
lines changed

2 files changed

+64
-5
lines changed

pkg/server/smb/server.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func normalizeWindowsPath(path string) string {
2525
return normalizedPath
2626
}
2727

28-
func normalizeMappingPath(path string) string {
28+
func getRootMappingPath(path string) (string, error) {
2929
items := strings.Split(path, "\\")
3030
parts := []string{}
3131
for _, s := range items {
@@ -36,7 +36,13 @@ func normalizeMappingPath(path string) string {
3636
}
3737
}
3838
}
39-
return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1])
39+
if len(parts) != 2 {
40+
klog.Errorf("remote path (%s) is invalid", path)
41+
return nil, fmt.Errorf("remote path (%s) is invalid", path)
42+
}
43+
// parts[0] is a smb host name
44+
// parts[1] is a smb share name
45+
return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1]), nil
4046
}
4147

4248
func NewServer(hostAPI smb.API, fsServer *fsserver.Server) (*Server, error) {
@@ -57,7 +63,9 @@ func (s *Server) NewSmbGlobalMapping(context context.Context, request *internal.
5763
return response, fmt.Errorf("remote path is empty")
5864
}
5965

60-
mappingPath := normalizeMappingPath(remotePath)
66+
if mappingPath, err := getRootMappingPath(remotePath); err != nil {
67+
return response, err;
68+
}
6169

6270
isMapped, err := s.hostAPI.IsSmbMapped(mappingPath)
6371
if err != nil {
@@ -122,7 +130,10 @@ func (s *Server) RemoveSmbGlobalMapping(context context.Context, request *intern
122130
return response, fmt.Errorf("remote path is empty")
123131
}
124132

125-
mappingPath := normalizeMappingPath(remotePath)
133+
if mappingPath, err := getRootMappingPath(remotePath); err != nil {
134+
return response, err;
135+
}
136+
126137
err := s.hostAPI.RemoveSmbGlobalMapping(mappingPath)
127138
if err != nil {
128139
klog.Errorf("failed RemoveSmbGlobalMapping %v", err)

pkg/server/smb/server_test.go

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ func TestNewSmbGlobalMapping(t *testing.T) {
7979
expectError: true,
8080
},
8181
{
82-
remote: "\\test\\path",
82+
remote: "\\\\hostname\\path",
8383
username: "",
8484
password: "",
8585
version: v1,
@@ -111,3 +111,51 @@ func TestNewSmbGlobalMapping(t *testing.T) {
111111
}
112112
}
113113
}
114+
115+
func TestGetRootMappingPath(t *testing.T) {
116+
testCases := []struct {
117+
remote string
118+
expectResult string
119+
expectError bool
120+
}{
121+
{
122+
remote: "",
123+
expectResult: nil,
124+
expectError: true,
125+
},
126+
{
127+
remote: "hostname",
128+
expectResult: nil,
129+
expectError: true,
130+
},
131+
{
132+
remote: "\\\\hostname\\path",
133+
expectResult: "\\\\hostname\\path",
134+
expectError: false,
135+
},
136+
{
137+
remote: "\\\\hostname\\path\\",
138+
expectResult: "\\\\hostname\\path",
139+
expectError: false,
140+
},
141+
{
142+
remote: "\\\\hostname\\path\\subpath",
143+
expectResult: "\\\\hostname\\path",
144+
expectError: false,
145+
},
146+
}
147+
for _, tc := range testCases {
148+
result, err := getRootMappingPath(tc.remote)
149+
if tc.expectError && err == nil {
150+
t.Errorf("Expected error but getRootMappingPath returned a nil error")
151+
}
152+
if !tc.expectError {
153+
if err != nil {
154+
t.Errorf("Expected no errors but getRootMappingPath returned error: %v", err)
155+
}
156+
if expectResult != result {
157+
t.Errorf("Expected (%s) but getRootMappingPath returned (%s)", expectResult, result)
158+
}
159+
}
160+
}
161+
}

0 commit comments

Comments
 (0)