Skip to content

Commit 9134651

Browse files
committed
Automatically remove files when idle + add two tests
1 parent c809918 commit 9134651

File tree

3 files changed

+214
-29
lines changed

3 files changed

+214
-29
lines changed

db.go

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ import (
99
type db struct {
1010
*sql.DB
1111

12-
addFile *sql.Stmt
13-
remFile *sql.Stmt
14-
getFile *sql.Stmt
12+
addFile *sql.Stmt
13+
remFile *sql.Stmt
14+
contentType *sql.Stmt
1515

16-
addUse *sql.Stmt
16+
addUse *sql.Stmt
1717
shouldDelete *sql.Stmt
18+
cleanup *sql.Stmt
19+
pendingCleanup *sql.Stmt
1820
}
1921

2022
func openDB(driver, dsn string) (*db, error) {
@@ -61,9 +63,10 @@ func openDB(driver, dsn string) (*db, error) {
6163
func (db *db) initSchema() error {
6264
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS filedrop (
6365
uuid TEXT PRIMARY KEY NOT NULL,
66+
contentType TEXT DEFAULT NULL,
6467
uses INTEGER NOT NULL DEFAULT 0,
65-
maxUses INTEGER,
66-
storeUntil INTEGER
68+
maxUses INTEGER DEFAULT NULL,
69+
storeUntil INTEGER DEFAULT NULL
6770
)`)
6871
if err != nil {
6972
return err
@@ -73,15 +76,15 @@ func (db *db) initSchema() error {
7376

7477
func (db *db) initStmts() error {
7578
var err error
76-
db.addFile, err = db.Prepare(`INSERT INTO filedrop(uuid, maxUses, storeUntil) VALUES (?, ?, ?)`)
79+
db.addFile, err = db.Prepare(`INSERT INTO filedrop(uuid, contentType, maxUses, storeUntil) VALUES (?, ?, ?, ?)`)
7780
if err != nil {
7881
return err
7982
}
8083
db.remFile, err = db.Prepare(`DELETE FROM filedrop WHERE uuid = ?`)
8184
if err != nil {
8285
return err
8386
}
84-
db.getFile, err = db.Prepare(`SELECT uses, maxUses, storeUntil FROM filedrop WHERE uuid = ?`)
87+
db.contentType, err = db.Prepare(`SELECT contentType FROM filedrop WHERE uuid = ?`)
8588
if err != nil {
8689
return err
8790
}
@@ -93,18 +96,27 @@ func (db *db) initStmts() error {
9396
if err != nil {
9497
return err
9598
}
99+
db.pendingCleanup, err = db.Prepare(`SELECT uuid FROM filedrop WHERE storeUntil < ? OR maxUses == uses`)
100+
if err != nil {
101+
return err
102+
}
103+
db.cleanup, err = db.Prepare(`DELETE FROM filedrop WHERE storeUntil < ? OR maxUses == uses`)
104+
if err != nil {
105+
return err
106+
}
96107
return nil
97108
}
98109

99-
func (db *db) AddFile(tx *sql.Tx, uuid string, maxUses uint, storeUntil time.Time) error {
110+
func (db *db) AddFile(tx *sql.Tx, uuid string, contentType string, maxUses uint, storeUntil time.Time) error {
100111
maxUsesN := sql.NullInt64{Int64: int64(maxUses), Valid: maxUses != 0}
101112
storeUntilN := sql.NullInt64{Int64: storeUntil.Unix(), Valid: !storeUntil.IsZero()}
113+
contentTypeN := sql.NullString{String: contentType, Valid: contentType != ""}
102114

103115
if tx != nil {
104-
_, err := tx.Stmt(db.addFile).Exec(uuid, maxUsesN, storeUntilN)
116+
_, err := tx.Stmt(db.addFile).Exec(uuid, contentTypeN, maxUsesN, storeUntilN)
105117
return err
106118
} else {
107-
_, err := db.addFile.Exec(uuid, maxUsesN, storeUntilN)
119+
_, err := db.addFile.Exec(uuid, contentTypeN, maxUsesN, storeUntilN)
108120
return err
109121
}
110122
}
@@ -141,4 +153,47 @@ func (db *db) AddUse(tx *sql.Tx, uuid string) error {
141153
_, err := db.addUse.Exec(uuid, uuid)
142154
return err
143155
}
156+
}
157+
158+
func (db *db) ContentType(tx *sql.Tx, fileUUID string) (string, error) {
159+
var row *sql.Row
160+
if tx != nil {
161+
row = tx.Stmt(db.contentType).QueryRow(fileUUID)
162+
} else {
163+
row = db.contentType.QueryRow(fileUUID)
164+
}
165+
166+
res := ""
167+
return res, row.Scan(&res)
168+
}
169+
170+
func (db *db) UnreachableFiles(tx *sql.Tx) ([]string, error) {
171+
uuids := []string{}
172+
var rows *sql.Rows
173+
var err error
174+
if tx != nil {
175+
rows, err = tx.Stmt(db.pendingCleanup).Query(time.Now().Unix())
176+
} else {
177+
rows, err = db.pendingCleanup.Query(time.Now().Unix())
178+
}
179+
if err != nil {
180+
return uuids, err
181+
}
182+
for rows.Next() {
183+
uuid := ""
184+
if err := rows.Scan(&uuid); err != nil {
185+
return uuids, err
186+
}
187+
}
188+
return uuids, nil
189+
}
190+
191+
func (db *db) RemoveUnreachableFiles(tx *sql.Tx) error {
192+
if tx != nil {
193+
_, err := tx.Stmt(db.cleanup).Exec(time.Now().Unix())
194+
return err
195+
} else {
196+
_, err := db.cleanup.Exec(time.Now().Unix())
197+
return err
198+
}
144199
}

server.go

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,27 @@ type Server struct {
2323
DB *db
2424
Conf Config
2525
Logger *log.Logger
26+
27+
fileCleanerStopChan chan bool
2628
}
2729

2830
func New(conf Config) (*Server, error) {
2931
s := new(Server)
3032
var err error
3133

3234
s.Conf = conf
35+
s.fileCleanerStopChan = make(chan bool)
3336
s.Logger = log.New(os.Stderr, "filedrop ", log.LstdFlags)
3437
s.DB, err = openDB(conf.DB.Driver, conf.DB.DSN)
38+
39+
go s.fileCleaner()
40+
3541
return s, err
3642
}
3743

3844
// AddFile adds file to storage and returns assigned UUID which can be directly
3945
// substituted into URL.
40-
func (s *Server) AddFile(contents io.Reader, maxUses uint, storeUntil time.Time) (string, error) {
46+
func (s *Server) AddFile(contents io.Reader, contentType string, maxUses uint, storeUntil time.Time) (string, error) {
4147
fileUUID := uuid.NewV4()
4248
outLocation := filepath.Join(s.Conf.StorageDir, fileUUID.String())
4349

@@ -53,7 +59,7 @@ func (s *Server) AddFile(contents io.Reader, maxUses uint, storeUntil time.Time)
5359
if _, err := io.Copy(file, contents); err != nil {
5460
return "", errors.Wrap(err, "file write")
5561
}
56-
if err := s.DB.AddFile(nil, fileUUID.String(), maxUses, storeUntil); err != nil {
62+
if err := s.DB.AddFile(nil, fileUUID.String(), contentType, maxUses, storeUntil); err != nil {
5763
os.Remove(outLocation)
5864
return "", errors.Wrap(err, "db add")
5965
}
@@ -110,48 +116,54 @@ func (s *Server) OpenFile(fileUUID string) (io.Reader, error) {
110116
// Note that access using this function is equivalent to access
111117
// through HTTP API, so it will count against usage count, for example.
112118
// To avoid this use OpenFile(fileUUID).
113-
func (s *Server) GetFile(fileUUID string) (io.Reader, error) {
119+
func (s *Server) GetFile(fileUUID string) (r io.Reader, contentType string, err error) {
114120
// Just to check validity.
115-
_, err := uuid.FromString(fileUUID)
121+
_, err = uuid.FromString(fileUUID)
116122
if err != nil {
117-
return nil, errors.Wrap(err, "uuid parse")
123+
return nil, "", errors.Wrap(err, "uuid parse")
118124
}
119125

120126
tx, err := s.DB.Begin()
121127
if err != nil {
122-
return nil, errors.Wrap(err, "tx begin")
128+
return nil, "", errors.Wrap(err, "tx begin")
123129
}
124130
defer tx.Rollback() // rollback is no-op after commit
125131

126132
if s.DB.ShouldDelete(tx, fileUUID) {
127133
if err := s.removeFile(tx, fileUUID); err != nil {
128-
log.Println("Error while trying to remove file", fileUUID + ":", err)
134+
s.Logger.Println("Error while trying to remove file", fileUUID + ":", err)
129135

130136
}
131137
if err := tx.Commit(); err != nil {
132-
log.Println("Tx commit error:", err)
133-
return nil, err
138+
s.Logger.Println("Tx commit error:", err)
139+
return nil, "", err
134140
}
135-
return nil, ErrFileDoesntExists
141+
return nil, "", ErrFileDoesntExists
136142
}
137143
if err := s.DB.AddUse(tx, fileUUID); err != nil {
138-
return nil, errors.Wrap(err, "add use")
144+
return nil, "", errors.Wrap(err, "add use")
139145
}
140146

141147
fileLocation := filepath.Join(s.Conf.StorageDir, fileUUID)
142148
file, err := os.Open(fileLocation)
143149
if err != nil {
144150
if os.IsNotExist(err) {
145151
s.removeFile(tx, fileUUID)
146-
return nil, ErrFileDoesntExists
152+
return nil, "", ErrFileDoesntExists
147153
}
148-
return nil, err
154+
return nil, "", err
149155
}
150156
if err := tx.Commit(); err != nil {
151-
log.Println("Tx commit error:", err)
152-
return nil, errors.Wrap(err, "tx commit")
157+
s.Logger.Println("Tx commit error:", err)
158+
return nil, "", errors.Wrap(err, "tx commit")
153159
}
154-
return file, nil
160+
161+
ttype, err := s.DB.ContentType(nil, fileUUID)
162+
if err != nil {
163+
return nil, "", errors.Wrap(err, "content type query")
164+
}
165+
166+
return file, ttype, nil
155167
}
156168

157169
func (s *Server) acceptFile(w http.ResponseWriter, r *http.Request) {
@@ -205,7 +217,7 @@ func (s *Server) acceptFile(w http.ResponseWriter, r *http.Request) {
205217
}
206218
}
207219

208-
fileUUID, err := s.AddFile(r.Body, maxUses, storeUntil)
220+
fileUUID, err := s.AddFile(r.Body, r.Header.Get("Content-Type"), maxUses, storeUntil)
209221
if err != nil {
210222
w.WriteHeader(http.StatusInternalServerError)
211223
w.Write([]byte(err.Error()))
@@ -244,7 +256,7 @@ func (s *Server) serveFile(w http.ResponseWriter, r *http.Request) {
244256
return
245257
}
246258
fileUUID := splittenPath[len(splittenPath)-2]
247-
reader, err := s.GetFile(fileUUID)
259+
reader, ttype, err := s.GetFile(fileUUID)
248260
if err != nil {
249261
if err == ErrFileDoesntExists {
250262
w.WriteHeader(http.StatusNotFound)
@@ -256,6 +268,9 @@ func (s *Server) serveFile(w http.ResponseWriter, r *http.Request) {
256268
}
257269
return
258270
}
271+
if ttype != "" {
272+
w.Header().Set("Content-Type", ttype)
273+
}
259274
w.WriteHeader(http.StatusOK)
260275
_, err = io.Copy(w, reader)
261276
if err != nil {
@@ -276,5 +291,47 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
276291
}
277292

278293
func (s *Server) Close() error {
294+
// don't close DB if "cleaner" is doing something, wait for it to finish
295+
s.fileCleanerStopChan <- true
296+
<-s.fileCleanerStopChan
297+
279298
return s.DB.Close()
280299
}
300+
301+
func (s *Server) fileCleaner() {
302+
tick := time.NewTicker(time.Minute)
303+
for {
304+
select {
305+
case <-s.fileCleanerStopChan:
306+
s.fileCleanerStopChan <- true
307+
return
308+
case <-tick.C:
309+
s.cleanupFiles()
310+
}
311+
}
312+
}
313+
314+
func (s *Server) cleanupFiles() {
315+
tx, err := s.DB.Begin()
316+
if err != nil {
317+
s.Logger.Println("Failed to begin transaction for clean-up:", err)
318+
return
319+
}
320+
defer tx.Rollback() // rollback is no-op after commit
321+
322+
uuids, err := s.DB.UnreachableFiles(tx)
323+
if err != nil {
324+
s.Logger.Println("Failed to get list of files pending removal:", err)
325+
return
326+
}
327+
328+
for _, fileUUID := range uuids {
329+
if err := os.Remove(filepath.Join(s.Conf.StorageDir, fileUUID)); err !=nil {
330+
s.Logger.Println("Failed to remove file during clean-up:", err)
331+
}
332+
}
333+
334+
if err := tx.Commit(); err != nil {
335+
s.Logger.Println("Failed to begin transaction for clean-up:", err)
336+
}
337+
}

server_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,76 @@ func TestNonExistent(t *testing.T) {
210210
t.FailNow()
211211
}
212212
}
213+
214+
func TestContentTypePreserved(t *testing.T) {
215+
serv := initServ(filedrop.Default)
216+
ts := httptest.NewServer(serv)
217+
defer os.RemoveAll(serv.Conf.StorageDir)
218+
defer serv.Close()
219+
defer ts.Close()
220+
c := ts.Client()
221+
222+
url := string(doPOST(t, c, ts.URL + "/filedrop/meow.txt", "text/kitteh", strings.NewReader(file)))
223+
224+
t.Log("File URL:", url)
225+
226+
resp, err := c.Get(url)
227+
if err != nil {
228+
t.Error("GET:", err)
229+
t.FailNow()
230+
}
231+
defer resp.Body.Close()
232+
body, err := ioutil.ReadAll(resp.Body)
233+
if err != nil {
234+
t.Error("ioutil.ReadAll:", err)
235+
t.FailNow()
236+
}
237+
if resp.StatusCode / 100 != 2 {
238+
t.Error("GET: HTTP", resp.Status)
239+
t.Error("Body:", string(body))
240+
t.FailNow()
241+
}
242+
if resp.Header.Get("Content-Type") != "text/kitteh" {
243+
t.Log("Mismatched content type:")
244+
t.Log("\tWanted: 'text/kitteh'")
245+
t.Log("\tGot:", "'" + resp.Header.Get("Content-Type") + "'")
246+
t.Fail()
247+
}
248+
}
249+
250+
func testWithPrefix(t *testing.T, ts *httptest.Server, c *http.Client, prefix string) {
251+
var URL string
252+
t.Run("submit with prefix " + prefix, func(t *testing.T) {
253+
URL = string(doPOST(t, c, ts.URL + prefix + "/meow.txt", "text/plain", strings.NewReader(file)))
254+
})
255+
256+
if !strings.Contains(URL, prefix) {
257+
t.Errorf("Result URL doesn't contain prefix %v: %v", prefix, URL)
258+
t.FailNow()
259+
}
260+
261+
if URL != "" {
262+
t.Run("get with " + prefix, func(t *testing.T) {
263+
body := doGET(t, c, URL)
264+
if string(body) != file {
265+
t.Error("Got different file!")
266+
t.FailNow()
267+
}
268+
})
269+
}
270+
}
271+
272+
func TestPrefixAgnostic(t *testing.T) {
273+
// Server should be able to handle requests independently
274+
// from full URL.
275+
serv := initServ(filedrop.Default)
276+
ts := httptest.NewServer(serv)
277+
defer os.RemoveAll(serv.Conf.StorageDir)
278+
defer serv.Close()
279+
defer ts.Close()
280+
c := ts.Client()
281+
282+
testWithPrefix(t, ts, c, "/a/b/c/d/e/f/g")
283+
testWithPrefix(t, ts, c, "/a/f%20oo/g")
284+
testWithPrefix(t, ts, c, "")
285+
}

0 commit comments

Comments
 (0)