Skip to content

Commit e671a5d

Browse files
authored
feat: recursively scan folders [#120] (#124)
This PR refactors NewFileDirReader to use Go's standard filepath.WalkDir (see the path/filepath package documentation) for directory traversal instead of a custom recursive implementation. The function reads all files under a directory, optionally traversing subdirectories up to a configurable maxDepth. By returning fs.SkipDir, traversal stops immediately at directories beyond the allowed depth, or at any subdirectory when recursion is disabled, which avoids unnecessary filesystem reads and reduces resource usage. This approach is safe, efficient, and less error-prone than manual recursion, while remaining idiomatic and easy to maintain. Usage: ``` vt scan file YOUR_DIRECTORY (default: recursion disabled, maxDepth == 1) vt scan file YOUR_DIRECTORY -r (recursion enabled with the default maxDepth == 1) vt scan file YOUR_DIRECTORY -r -d 3 (recursion enabled, with maxDepth raised so that subdirectories are read until depth == 3 is reached in each subtree) ```
1 parent d340558 commit e671a5d

File tree

6 files changed

+290
-15
lines changed

6 files changed

+290
-15
lines changed

cmd/cmd.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,18 @@ func addThreadsFlag(flags *pflag.FlagSet) {
7474
"number of threads working in parallel")
7575
}
7676

77+
func addRecursive(flags *pflag.FlagSet) {
78+
flags.BoolP(
79+
"recursive", "r", false,
80+
"enable recursive traversal of subdirectories")
81+
}
82+
83+
func addMaxDepth(flags *pflag.FlagSet) {
84+
flags.IntP(
85+
"maxDepth", "d", 1,
86+
"maximum recursion depth for directory traversal")
87+
}
88+
7789
func addIDOnlyFlag(flags *pflag.FlagSet) {
7890
flags.BoolP(
7991
"identifiers-only", "I", false,

cmd/scan.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ func NewScanFileCmd() *cobra.Command {
164164
if len(args) == 1 && args[0] == "-" {
165165
argReader = utils.NewStringIOReader(os.Stdin)
166166
} else if len(args) == 1 && utils.IsDir(args[0]) {
167-
argReader, _ = utils.NewFileDirReader(args[0])
167+
recursive := viper.GetBool("recursive")
168+
maxDepth := viper.GetInt("maxDepth")
169+
argReader, _ = utils.NewFileDirReader(args[0], recursive, maxDepth)
168170
} else {
169171
argReader = utils.NewStringArrayReader(args)
170172
}
@@ -188,6 +190,8 @@ func NewScanFileCmd() *cobra.Command {
188190
},
189191
}
190192

193+
addRecursive(cmd.Flags())
194+
addMaxDepth(cmd.Flags())
191195
addThreadsFlag(cmd.Flags())
192196
addOpenInVTFlag(cmd.Flags())
193197
addPasswordFlag(cmd.Flags())

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ require (
2424
github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect
2525
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
2626
github.com/fsnotify/fsnotify v1.7.0 // indirect
27+
github.com/google/go-cmp v0.7.0 // indirect
2728
github.com/hashicorp/hcl v1.0.0 // indirect
2829
github.com/inconshreveable/mousetrap v1.1.0 // indirect
2930
github.com/magiconair/properties v1.8.7 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
2424
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
2525
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
2626
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
27+
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
28+
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
2729
github.com/gosuri/uitable v0.0.4 h1:IG2xLKRvErL3uhY6e1BylFzG+aJiwQviDDTfOKeKTpY=
2830
github.com/gosuri/uitable v0.0.4/go.mod h1:tKR86bXuXPZazfOTG1FIzvjIdXzd0mo4Vtn16vt0PJo=
2931
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=

utils/file_utils.go

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,56 @@
1414
package utils
1515

1616
import (
17+
"io/fs"
1718
"os"
18-
"path"
19+
"path/filepath"
20+
"strings"
1921
)
2022

21-
// FileDirReader returns all files inside a given directory
22-
// as a StringArrayReader
23-
func NewFileDirReader(fileDir string) (*StringArrayReader, error) {
24-
files, err := os.ReadDir(fileDir)
23+
// NewFileDirReader reads all files from the given directory `fileDir`.
24+
// It can optionally traverse subdirectories if `recursive` is true,
25+
// and will limit recursion to `maxDepth` levels if specified.
26+
//
27+
// Uses the standard library's `filepath.WalkDir` to traverse directories efficiently,
28+
// and `fs.SkipDir` to skip directories when recursion is disabled or maxDepth is reached.
29+
func NewFileDirReader(fileDir string, recursive bool, maxDepth int) (*StringArrayReader, error) {
30+
var filePaths []string
31+
rootDepth := pathDepth(fileDir)
32+
33+
// filePaths is safely appended within WalkDir because WalkDir executes the callback sequentially.
34+
// No race conditions occur in this implementation, even with slice reallocation.
35+
err := filepath.WalkDir(fileDir, func(path string, d fs.DirEntry, err error) error {
36+
if err != nil {
37+
return err
38+
}
39+
40+
if !d.IsDir() {
41+
filePaths = append(filePaths, path)
42+
return nil
43+
}
44+
45+
currentDepth := pathDepth(path) - rootDepth
46+
// we skip directory if recursive is disabled or
47+
// if we reached configured maxDepth
48+
if !recursive && path != fileDir ||
49+
currentDepth >= maxDepth {
50+
return fs.SkipDir
51+
}
52+
53+
return nil
54+
})
55+
2556
if err != nil {
2657
return nil, err
2758
}
28-
fileNames := []string{}
29-
for _, f := range files {
30-
// Skip subdirectories
31-
if f.IsDir() {
32-
continue
33-
}
34-
fileNames = append(fileNames, path.Join(fileDir, f.Name()))
35-
}
36-
return &StringArrayReader{strings: fileNames}, nil
59+
return &StringArrayReader{strings: filePaths}, nil
60+
}
61+
62+
// pathDepth reports the number of separator-delimited components in path
// after it has been normalised with filepath.Clean.
//
// An empty path cleans to "." and therefore counts as one component. The
// empty element in front of a leading separator is counted as well, so an
// absolute path such as "/a" has depth 2 on slash-separated systems.
func pathDepth(path string) int {
	cleaned := filepath.Clean(path)
	// n separators delimit n+1 components.
	return strings.Count(cleaned, string(filepath.Separator)) + 1
}
3868

3969
// IsDir function returns whether a file is a directory or not

utils/file_utils_test.go

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
// Copyright © 2017 The VirusTotal CLI authors. All Rights Reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package utils
15+
16+
import (
17+
"os"
18+
"path/filepath"
19+
"strings"
20+
"testing"
21+
22+
"github.com/google/go-cmp/cmp"
23+
)
24+
25+
func Test_NewFileDirReader(t *testing.T) {
26+
t.Parallel()
27+
28+
useCases := []struct {
29+
name string
30+
31+
directories []string
32+
files []string
33+
recursive bool
34+
maxDepth int
35+
36+
want func(string) *StringArrayReader
37+
}{
38+
39+
{
40+
name: "want get back empty files",
41+
directories: []string{},
42+
files: []string{},
43+
recursive: true,
44+
maxDepth: 2,
45+
want: func(d string) *StringArrayReader {
46+
return &StringArrayReader{
47+
strings: nil,
48+
}
49+
},
50+
},
51+
{
52+
name: "want to read single file",
53+
directories: []string{},
54+
files: []string{"z.txt"},
55+
recursive: true,
56+
maxDepth: 2,
57+
want: func(d string) *StringArrayReader {
58+
return &StringArrayReader{
59+
strings: []string{
60+
filepath.Join(d, "z.txt"),
61+
},
62+
}
63+
},
64+
},
65+
{
66+
name: "want read all files within all subdirectories",
67+
directories: []string{"sub", "sub/sub", ".hidden"},
68+
files: []string{"a.txt", "b.txt", "sub/c.txt", "sub/sub/d.txt", ".hidden/config"},
69+
recursive: true,
70+
maxDepth: 3,
71+
want: func(d string) *StringArrayReader {
72+
return &StringArrayReader{
73+
strings: []string{
74+
filepath.Join(d, ".hidden/config"),
75+
filepath.Join(d, "a.txt"),
76+
filepath.Join(d, "b.txt"),
77+
filepath.Join(d, "sub/c.txt"),
78+
filepath.Join(d, "sub/sub/d.txt"),
79+
},
80+
}
81+
},
82+
},
83+
{
84+
name: "want to ignore all subdirectories",
85+
directories: []string{"sub", "sub/sub"},
86+
files: []string{"a.txt", "b.txt", "sub/c.txt", "sub/sub/d.txt"},
87+
recursive: false,
88+
maxDepth: 10,
89+
want: func(d string) *StringArrayReader {
90+
return &StringArrayReader{
91+
strings: []string{
92+
filepath.Join(d, "a.txt"),
93+
filepath.Join(d, "b.txt"),
94+
},
95+
}
96+
},
97+
},
98+
{
99+
name: "want to read until first depth",
100+
directories: []string{"sub", "sub/sub"},
101+
files: []string{"a.txt", "b.txt", "sub/c.txt", "sub/sub/d.txt"},
102+
recursive: true,
103+
maxDepth: 1,
104+
want: func(d string) *StringArrayReader {
105+
return &StringArrayReader{
106+
strings: []string{
107+
filepath.Join(d, "a.txt"),
108+
filepath.Join(d, "b.txt"),
109+
},
110+
}
111+
},
112+
},
113+
{
114+
name: "want to read until second depth",
115+
directories: []string{"sub", "sub/sub"},
116+
files: []string{"a.txt", "b.txt", "sub/c.txt", "sub/sub/d.txt"},
117+
recursive: true,
118+
maxDepth: 2,
119+
want: func(d string) *StringArrayReader {
120+
return &StringArrayReader{
121+
strings: []string{
122+
filepath.Join(d, "a.txt"),
123+
filepath.Join(d, "b.txt"),
124+
filepath.Join(d, "sub/c.txt"),
125+
},
126+
}
127+
},
128+
},
129+
}
130+
131+
for _, uc := range useCases {
132+
t.Run(uc.name, func(t *testing.T) {
133+
134+
// create a temp directory, and will clean up after test ends
135+
rootDir := t.TempDir()
136+
137+
for _, d := range uc.directories {
138+
path := filepath.Join(rootDir, d)
139+
rwxPerm := os.FileMode(0755)
140+
if err := os.Mkdir(path, rwxPerm); err != nil {
141+
t.Fatalf("unexpected error while Mkdir %v", err)
142+
}
143+
}
144+
145+
for _, f := range uc.files {
146+
path := filepath.Join(rootDir, f)
147+
rwPerm := os.FileMode(0644)
148+
if err := os.WriteFile(path, []byte("hello world!"), rwPerm); err != nil {
149+
t.Fatalf("unexpected error while WriteFile %v", err)
150+
}
151+
}
152+
153+
got, err := NewFileDirReader(rootDir, uc.recursive, uc.maxDepth)
154+
if err != nil {
155+
t.Errorf("unexpected error while NewFileDirReader err:%v", err)
156+
}
157+
if diff := cmp.Diff(
158+
uc.want(rootDir),
159+
got,
160+
cmp.AllowUnexported(StringArrayReader{}),
161+
); diff != "" {
162+
t.Errorf("unexpected StringArrayReader mismatch (-want +got):\n%s", diff)
163+
}
164+
})
165+
}
166+
}
167+
168+
func Test_NewFileDirReader_Error(t *testing.T) {
169+
t.Parallel()
170+
171+
rootDir := t.TempDir()
172+
noPerm := os.FileMode(0000)
173+
if err := os.WriteFile(filepath.Join(rootDir, "a.txt"), []byte("hello world!"), noPerm); err != nil {
174+
t.Fatalf("unexpected error while WriteFile %v", err)
175+
}
176+
path := filepath.Join(rootDir, "sub")
177+
if err := os.Mkdir(path, noPerm); err != nil {
178+
t.Fatalf("unexpected error while Mkdir %v", err)
179+
}
180+
_, err := NewFileDirReader(rootDir, false, 10)
181+
if err != nil {
182+
t.Errorf("unexpected error while NewFileDirReader err:%v", err)
183+
}
184+
_, err = NewFileDirReader(rootDir, true, 10)
185+
if !strings.Contains(err.Error(), "permission denied") {
186+
t.Errorf("unexpected error permissions denied message got:%v", err.Error())
187+
}
188+
}
189+
190+
func Test_pathDepth(t *testing.T) {
191+
t.Parallel()
192+
193+
useCases := []struct {
194+
name string
195+
dir string
196+
want int
197+
}{
198+
{
199+
name: "want one depth when empty directory",
200+
want: 1,
201+
},
202+
{
203+
name: "want one depth when simple directory",
204+
dir: "a",
205+
want: 1,
206+
},
207+
{
208+
name: "want 2 depth when sub directory",
209+
dir: "a/a",
210+
want: 2,
211+
},
212+
{
213+
name: "want 3 depth when sub directory",
214+
dir: "a-a/b_bb__/cccc/",
215+
want: 3,
216+
},
217+
}
218+
219+
for _, uc := range useCases {
220+
t.Run(uc.name, func(t *testing.T) {
221+
if got, want := pathDepth(uc.dir), uc.want; got != want {
222+
t.Errorf("unexpected pathDepth, got:%v, want:%v", got, want)
223+
}
224+
})
225+
}
226+
}

0 commit comments

Comments
 (0)