Skip to content

Commit 3442c47

Browse files
authored
fix: prevent ScanFile... functions to consume too much memory when scanning large files (#42)
The previous implementation always read the whole file into a memory buffer before sending it to VirusTotal, this imposes a large memory footprint when scanning large files. With the new implementation, we try to determine the payload size without reading the whole file. If the size can't be determined, then we fall back to reading the whole file into a buffer. In most cases reading the whole file is not required. Closes #41.
1 parent b198130 commit 3442c47

File tree

1 file changed

+63
-26
lines changed

1 file changed

+63
-26
lines changed

filescan.go

Lines changed: 63 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"mime/multipart"
2222
"net/url"
2323
"os"
24+
"strings"
2425
)
2526

2627
type progressReader struct {
@@ -46,40 +47,76 @@ type FileScanner struct {
4647

4748
func (s *FileScanner) scanWithParameters(
4849
r io.Reader, filename string, progress chan<- float32, parameters map[string]string) (*Object, error) {
49-
var uploadURL *url.URL
50-
var payloadSize int64
5150

52-
b := bytes.Buffer{}
51+
// File size is initially unknown.
52+
fileSize := int64(-1)
5353

54-
// Create multipart writer for the file
55-
w := multipart.NewWriter(&b)
56-
f, err := w.CreateFormFile("file", filename)
57-
if err != nil {
58-
return nil, err
54+
// Try to determine the size of the file being uploaded.
55+
switch v := r.(type) {
56+
case *os.File:
57+
if stat, err := v.Stat(); err == nil {
58+
fileSize = stat.Size()
59+
}
60+
case *bytes.Buffer:
61+
fileSize = int64(v.Len())
62+
case *bytes.Reader:
63+
fileSize = int64(v.Len())
64+
case *strings.Reader:
65+
fileSize = int64(v.Len())
66+
default:
5967
}
6068

61-
// Copy data from input stream to the multiparted file
62-
if payloadSize, err = io.Copy(f, r); err != nil {
63-
return nil, err
69+
// If the size was not determined by other means, read the entire
70+
// content into a buffer to determine the size.
71+
if fileSize == -1 {
72+
b := bytes.Buffer{}
73+
io.Copy(&b, r)
74+
fileSize = int64(b.Len())
75+
r = &b
6476
}
6577

66-
if parameters != nil {
78+
pipeReader, pipeWriter := io.Pipe()
79+
multipartWriter := multipart.NewWriter(pipeWriter)
80+
81+
// Read data from the input reader `r`, and write it into the multipart
82+
// writer in a separate goroutine using a pipe. Data is read from `r`
83+
// only as requested by the HTTP client to avoid loading all the data
84+
// into memory.
85+
go func() {
86+
defer pipeWriter.Close()
87+
defer multipartWriter.Close()
88+
89+
f, err := multipartWriter.CreateFormFile("file", filename)
90+
if err != nil {
91+
pipeWriter.CloseWithError(err)
92+
return
93+
}
94+
95+
if _, err := io.Copy(f, r); err != nil {
96+
pipeWriter.CloseWithError(err)
97+
return
98+
}
99+
67100
for key, val := range parameters {
68-
if err := w.WriteField(key, val); err != nil {
69-
return nil, err
101+
if err := multipartWriter.WriteField(key, val); err != nil {
102+
pipeWriter.CloseWithError(err)
103+
return
70104
}
71105
}
72-
}
106+
}()
73107

74-
w.Close()
108+
var uploadURL *url.URL
109+
var err error
75110

76-
if payloadSize > maxFileSize {
111+
// Choose upload URL based on the file size. If the size is known and less
112+
// than maxPayloadSize, we can upload directly to /files. If the size is
113+
// unknown or larger than maxPayloadSize, we need to request an upload URL
114+
// first. If the size is larger than maxFileSize, we return an error.
115+
if fileSize > maxFileSize {
77116
return nil, fmt.Errorf("file size can't be larger than %d bytes", maxFileSize)
78-
} else if payloadSize > maxPayloadSize {
79-
// Payload is bigger than supported by AppEngine in a POST request,
80-
// let's ask for an upload URL.
117+
} else if fileSize > maxPayloadSize {
81118
var u string
82-
if _, err := s.cli.GetData(URL("files/upload_url"), &u); err != nil {
119+
if _, err = s.cli.GetData(URL("files/upload_url"), &u); err != nil {
83120
return nil, err
84121
}
85122
if uploadURL, err = url.Parse(u); err != nil {
@@ -89,14 +126,14 @@ func (s *FileScanner) scanWithParameters(
89126
uploadURL = URL("files")
90127
}
91128

92-
pr := &progressReader{
93-
reader: &b,
94-
total: int64(b.Len()),
129+
progressReader := &progressReader{
130+
reader: pipeReader,
131+
total: fileSize,
95132
progressCh: progress}
96133

97-
headers := map[string]string{"Content-Type": w.FormDataContentType()}
134+
headers := map[string]string{"Content-Type": multipartWriter.FormDataContentType()}
98135

99-
httpResp, err := s.cli.sendRequest("POST", uploadURL, pr, headers)
136+
httpResp, err := s.cli.sendRequest("POST", uploadURL, progressReader, headers)
100137
if err != nil {
101138
return nil, err
102139
}

0 commit comments

Comments
 (0)