Skip to content

Commit 723692c

Browse files
authored
refact pkg/cwhub (#3666)
* cwhub: stricter SafePath()
* pkg/cwhub: compute hash on the fly
* pkg/cwhub: named return values
* pkg/cwhub: don't call panic() in tests
* pkg/cwhub: refact relativePathComponents
* pkg/cwhub: refact sync.go
* pkg/cwhub: extract method collectSpecs()
* pkg/cwhub: simplify url handling
1 parent f52d710 commit 723692c

File tree

9 files changed

+239
-120
lines changed

9 files changed

+239
-120
lines changed

pkg/cwhub/cwhub.go

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
package cwhub
22

33
import (
4-
"fmt"
54
"net/http"
6-
"path/filepath"
7-
"strings"
85
"time"
96

107
"github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent"
@@ -25,22 +22,3 @@ var HubClient = &http.Client{
2522
Timeout: 120 * time.Second,
2623
Transport: &hubTransport{http.DefaultTransport},
2724
}
28-
29-
// SafePath returns a joined path and ensures that it does not escape the base directory.
30-
func SafePath(dir, filePath string) (string, error) {
31-
absBaseDir, err := filepath.Abs(filepath.Clean(dir))
32-
if err != nil {
33-
return "", err
34-
}
35-
36-
absFilePath, err := filepath.Abs(filepath.Join(dir, filePath))
37-
if err != nil {
38-
return "", err
39-
}
40-
41-
if !strings.HasPrefix(absFilePath, absBaseDir) {
42-
return "", fmt.Errorf("path %s escapes base directory %s", filePath, dir)
43-
}
44-
45-
return absFilePath, nil
46-
}

pkg/cwhub/cwhub_test.go

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func testHubOld(t *testing.T, update bool) *Hub {
7575

7676
// envSetup initializes the temporary hub and mocks the http client.
7777
func envSetup(t *testing.T) *Hub {
78-
setResponseByPath()
78+
setResponseByPath(t)
7979
log.SetLevel(log.DebugLevel)
8080

8181
defaultTransport := HubClient.Transport
@@ -121,27 +121,23 @@ func (t *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
121121
return response, nil
122122
}
123123

124-
func fileToStringX(path string) string {
124+
func fileToStringX(t *testing.T, path string) string {
125125
f, err := os.Open(path)
126-
if err != nil {
127-
panic(err)
128-
}
126+
require.NoError(t, err)
129127
defer f.Close()
130128

131129
data, err := io.ReadAll(f)
132-
if err != nil {
133-
panic(err)
134-
}
130+
require.NoError(t, err)
135131

136132
return strings.ReplaceAll(string(data), "\r\n", "\n")
137133
}
138134

139-
func setResponseByPath() {
135+
func setResponseByPath(t *testing.T) {
140136
responseByPath = map[string]string{
141-
"/crowdsecurity/master/parsers/s01-parse/crowdsecurity/foobar_parser.yaml": fileToStringX("./testdata/foobar_parser.yaml"),
142-
"/crowdsecurity/master/parsers/s01-parse/crowdsecurity/foobar_subparser.yaml": fileToStringX("./testdata/foobar_parser.yaml"),
143-
"/crowdsecurity/master/collections/crowdsecurity/test_collection.yaml": fileToStringX("./testdata/collection_v1.yaml"),
144-
"/crowdsecurity/master/.index.json": fileToStringX("./testdata/index1.json"),
137+
"/crowdsecurity/master/parsers/s01-parse/crowdsecurity/foobar_parser.yaml": fileToStringX(t, "./testdata/foobar_parser.yaml"),
138+
"/crowdsecurity/master/parsers/s01-parse/crowdsecurity/foobar_subparser.yaml": fileToStringX(t, "./testdata/foobar_parser.yaml"),
139+
"/crowdsecurity/master/collections/crowdsecurity/test_collection.yaml": fileToStringX(t, "./testdata/collection_v1.yaml"),
140+
"/crowdsecurity/master/.index.json": fileToStringX(t, "./testdata/index1.json"),
145141
"/crowdsecurity/master/scenarios/crowdsecurity/foobar_scenario.yaml": `filter: true
146142
name: crowdsecurity/foobar_scenario`,
147143
"/crowdsecurity/master/scenarios/crowdsecurity/barfoo_scenario.yaml": `filter: true

pkg/cwhub/download.go

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,51 +32,39 @@ type ContentProvider interface {
3232
}
3333

3434
// urlTo builds the URL to download a file from the remote hub.
35-
func (d *Downloader) urlTo(remotePath string) (string, error) {
35+
func (d *Downloader) urlTo(remotePath string) (*url.URL, error) {
3636
// the template must contain two string placeholders
3737
if fmt.Sprintf(d.URLTemplate, "%s", "%s") != d.URLTemplate {
38-
return "", fmt.Errorf("invalid URL template '%s'", d.URLTemplate)
38+
return nil, fmt.Errorf("invalid URL template '%s'", d.URLTemplate)
3939
}
4040

41-
return fmt.Sprintf(d.URLTemplate, d.Branch, remotePath), nil
42-
}
41+
raw := fmt.Sprintf(d.URLTemplate, d.Branch, remotePath)
4342

44-
// addURLParam adds a parameter with a value (ex. "with_content=true") to the URL if it's not already present.
45-
func addURLParam(rawURL string, param string, value string) (string, error) {
46-
parsedURL, err := url.Parse(rawURL)
43+
parsed, err := url.Parse(raw)
4744
if err != nil {
48-
return "", fmt.Errorf("failed to parse URL: %w", err)
45+
return nil, fmt.Errorf("failed to parse URL: %w", err)
4946
}
5047

51-
query := parsedURL.Query()
52-
53-
if _, exists := query[param]; !exists {
54-
query.Add(param, value)
55-
}
56-
57-
parsedURL.RawQuery = query.Encode()
58-
59-
return parsedURL.String(), nil
48+
return parsed, nil
6049
}
6150

6251
// FetchIndex downloads the index from the hub and writes it to the filesystem.
6352
// It uses a temporary file to avoid partial downloads, and won't overwrite the original
6453
// if it has not changed.
6554
// It returns true if the file has been updated, false if it was already up to date.
66-
func (d *Downloader) FetchIndex(ctx context.Context, destPath string, withContent bool, logger *logrus.Logger) (bool, error) {
55+
func (d *Downloader) FetchIndex(ctx context.Context, destPath string, withContent bool, logger *logrus.Logger) (downloaded bool, err error) {
6756
url, err := d.urlTo(".index.json")
6857
if err != nil {
6958
return false, fmt.Errorf("failed to build hub index request: %w", err)
7059
}
7160

7261
if withContent {
73-
url, err = addURLParam(url, "with_content", "true")
74-
if err != nil {
75-
return false, fmt.Errorf("failed to add 'with_content' parameter to URL: %w", err)
76-
}
62+
q := url.Query()
63+
q.Set("with_content", "true")
64+
url.RawQuery = q.Encode()
7765
}
7866

79-
downloaded, err := downloader.
67+
downloaded, err = downloader.
8068
New().
8169
WithHTTPClient(HubClient).
8270
ToFile(destPath).
@@ -86,7 +74,7 @@ func (d *Downloader) FetchIndex(ctx context.Context, destPath string, withConten
8674
BeforeRequest(func(_ *http.Request) {
8775
fmt.Println("Downloading " + destPath)
8876
}).
89-
Download(ctx, url)
77+
Download(ctx, url.String())
9078
if err != nil {
9179
return false, err
9280
}
@@ -97,13 +85,13 @@ func (d *Downloader) FetchIndex(ctx context.Context, destPath string, withConten
9785
// FetchContent downloads the content to the specified path, through a temporary file
9886
// to avoid partial downloads.
9987
// If the hash does not match, it logs a warning and does not overwrite the destination file.
100-
func (d *Downloader) FetchContent(ctx context.Context, remotePath, destPath, wantHash string, logger *logrus.Logger) (bool, string, error) {
101-
url, err := d.urlTo(remotePath)
88+
func (d *Downloader) FetchContent(ctx context.Context, remotePath, destPath, wantHash string, logger *logrus.Logger) (downloaded bool, url string, err error) {
89+
u, err := d.urlTo(remotePath)
10290
if err != nil {
10391
return false, "", fmt.Errorf("failed to build request: %w", err)
10492
}
10593

106-
downloaded, err := downloader.
94+
downloaded, err = downloader.
10795
New().
10896
WithHTTPClient(HubClient).
10997
ToFile(destPath).
@@ -112,7 +100,7 @@ func (d *Downloader) FetchContent(ctx context.Context, remotePath, destPath, wan
112100
WithLogger(logger.WithField("url", url)).
113101
CompareContent().
114102
VerifyHash("sha256", wantHash).
115-
Download(ctx, url)
103+
Download(ctx, u.String())
116104

117105
var hasherr downloader.HashMismatchError
118106

@@ -123,5 +111,5 @@ func (d *Downloader) FetchContent(ctx context.Context, remotePath, destPath, wan
123111
return false, "", err
124112
}
125113

126-
return downloaded, url, nil
114+
return downloaded, u.String(), nil
127115
}

pkg/cwhub/fetch.go

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
package cwhub
22

33
import (
4+
"bytes"
45
"context"
56
"crypto"
67
"encoding/base64"
78
"encoding/hex"
89
"fmt"
10+
"io"
911
"os"
1012
"path/filepath"
13+
14+
"github.com/crowdsecurity/go-cs-lib/downloader"
1115
)
1216

1317
// writeEmbeddedContentTo writes the embedded content to the specified path and checks the hash.
@@ -24,24 +28,32 @@ func (i *Item) writeEmbeddedContentTo(destPath, wantHash string) error {
2428
}
2529

2630
dir := filepath.Dir(destPath)
31+
reader := bytes.NewReader(content)
32+
hash := crypto.SHA256.New()
2733

34+
tee := io.TeeReader(reader, hash)
2835
if err := os.MkdirAll(dir, 0o755); err != nil {
2936
return fmt.Errorf("while creating %s: %w", dir, err)
3037
}
3138

32-
// check sha256
33-
hash := crypto.SHA256.New()
34-
if _, err := hash.Write(content); err != nil {
35-
return fmt.Errorf("while hashing %s: %w", i.Name, err)
39+
f, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
40+
if err != nil {
41+
return err
3642
}
3743

38-
gotHash := hex.EncodeToString(hash.Sum(nil))
39-
if gotHash != wantHash {
40-
return fmt.Errorf("hash mismatch: expected %s, got %s. The index file is invalid, please run 'cscli hub update' and try again", wantHash, gotHash)
44+
defer f.Close()
45+
46+
if _, err := io.Copy(f, tee); err != nil {
47+
return err
4148
}
4249

43-
if err := os.WriteFile(destPath, content, 0o600); err != nil {
44-
return fmt.Errorf("while writing %s: %w", destPath, err)
50+
gotHash := hex.EncodeToString(hash.Sum(nil))
51+
if gotHash != wantHash {
52+
return fmt.Errorf("%w. The index file is invalid, please run 'cscli hub update' and try again",
53+
downloader.HashMismatchError{
54+
Expected: wantHash,
55+
Got: gotHash,
56+
})
4557
}
4658

4759
return nil

pkg/cwhub/relativepath.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,25 @@ import (
55
"strings"
66
)
77

8-
// relativePathComponents returns the list of path components after baseDir.
9-
// If path is not inside baseDir, it returns an empty slice.
10-
func relativePathComponents(path string, baseDir string) []string {
8+
// relativePathComponents returns the list of path components after baseDir,
9+
// and a boolean indicating whether path is strictly inside baseDir (false if path is baseDir itself or lies outside it).
10+
func relativePathComponents(path string, baseDir string) ([]string, bool) {
1111
absPath, err := filepath.Abs(path)
1212
if err != nil {
13-
return []string{}
13+
// filepath.Abs failed, most likely because the current working directory could not be determined
14+
return nil, false
1415
}
1516

1617
absBaseDir, err := filepath.Abs(baseDir)
1718
if err != nil {
18-
return []string{}
19+
return nil, false
1920
}
2021

2122
// is path inside baseDir?
2223
relPath, err := filepath.Rel(absBaseDir, absPath)
2324
if err != nil || strings.HasPrefix(relPath, "..") || relPath == "." {
24-
return []string{}
25+
return nil, false
2526
}
2627

27-
return strings.Split(relPath, string(filepath.Separator))
28+
return strings.Split(relPath, string(filepath.Separator)), true
2829
}

pkg/cwhub/relativepath_test.go

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,62 +11,72 @@ func TestRelativePathComponents(t *testing.T) {
1111
name string
1212
path string
1313
baseDir string
14-
expected []string
14+
wantSubs []string
15+
wantOk bool
1516
}{
1617
{
1718
name: "Path within baseDir",
1819
path: "/home/user/project/src/file.go",
1920
baseDir: "/home/user/project",
20-
expected: []string{"src", "file.go"},
21+
wantSubs: []string{"src", "file.go"},
22+
wantOk: true,
2123
},
2224
{
2325
name: "Path is baseDir",
2426
path: "/home/user/project",
2527
baseDir: "/home/user/project",
26-
expected: []string{},
28+
wantSubs: nil,
29+
wantOk: false,
2730
},
2831
{
2932
name: "Path outside baseDir",
3033
path: "/home/user/otherproject/src/file.go",
3134
baseDir: "/home/user/project",
32-
expected: []string{},
35+
wantSubs: nil,
36+
wantOk: false,
3337
},
3438
{
3539
name: "Path is subdirectory of baseDir",
3640
path: "/home/user/project/src/",
3741
baseDir: "/home/user/project",
38-
expected: []string{"src"},
42+
wantSubs: []string{"src"},
43+
wantOk: true,
3944
},
4045
{
4146
name: "Relative paths",
4247
path: "project/src/file.go",
4348
baseDir: "project",
44-
expected: []string{"src", "file.go"},
49+
wantSubs: []string{"src", "file.go"},
50+
wantOk: true,
4551
},
4652
{
4753
name: "BaseDir with trailing slash",
4854
path: "/home/user/project/src/file.go",
4955
baseDir: "/home/user/project/",
50-
expected: []string{"src", "file.go"},
56+
wantSubs: []string{"src", "file.go"},
57+
wantOk: true,
5158
},
5259
{
5360
name: "Empty baseDir",
5461
path: "/home/user/project/src/file.go",
5562
baseDir: "",
56-
expected: []string{},
63+
wantSubs: nil,
64+
wantOk: false,
5765
},
5866
{
5967
name: "Empty path",
6068
path: "",
6169
baseDir: "/home/user/project",
62-
expected: []string{},
70+
wantSubs: nil,
71+
wantOk: false,
6372
},
6473
}
6574

6675
for _, tt := range tests {
6776
t.Run(tt.name, func(t *testing.T) {
68-
result := relativePathComponents(tt.path, tt.baseDir)
69-
assert.Equal(t, tt.expected, result)
77+
got, ok := relativePathComponents(tt.path, tt.baseDir)
78+
assert.Equal(t, tt.wantSubs, got)
79+
assert.Equal(t, tt.wantOk, ok)
7080
})
7181
}
7282
}

0 commit comments

Comments
 (0)