Skip to content

Commit 75dda54

Browse files
committed
refactor snapshot solution into combined peerdb + udp conn db snapshot
1 parent cfb314d commit 75dda54

File tree

17 files changed

+261
-212
lines changed

17 files changed

+261
-212
lines changed

config/embedded/trakx.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,5 +101,5 @@ db:
101101
# Database backup interval. Leave blank to disable backups on interval.
102102
interval: 300s
103103

104-
# Backup file path. Leave blank for default of `$cache/db`
104+
# Combined backup snapshot file path. Leave blank for default of `$cache/db`
105105
path:

daemon/backup.go

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/crimist/trakx/config"
1414
"github.com/crimist/trakx/stats"
1515
"github.com/crimist/trakx/storage/database"
16+
"github.com/crimist/trakx/tracker/udp/connections"
1617
"github.com/pkg/errors"
1718
"go.uber.org/zap"
1819
)
@@ -71,8 +72,18 @@ func ImportBackup(conf *config.Configuration, reader io.Reader) (err error) {
7172
}
7273

7374
tee := io.TeeReader(reader, tmpFile)
74-
if err = validationDB.Restore(tee); err != nil {
75-
return errors.Wrap(err, "backup validation failed")
75+
connSnapshot, dbReader, err := splitCombinedSnapshot(tee)
76+
if err != nil {
77+
return errors.Wrap(err, "failed to parse combined snapshot")
78+
}
79+
if err = validationDB.Restore(dbReader); err != nil {
80+
return errors.Wrap(err, "database validation failed")
81+
}
82+
if len(connSnapshot) > 0 {
83+
connValidation := connections.NewConnections(0, time.Minute, 0)
84+
if err := connValidation.Unmarshal(connSnapshot); err != nil {
85+
return errors.Wrap(err, "connection validation failed")
86+
}
7687
}
7788
if err = tmpFile.Sync(); err != nil {
7889
return errors.Wrap(err, "failed to sync backup file")
@@ -90,8 +101,8 @@ func ImportBackup(conf *config.Configuration, reader io.Reader) (err error) {
90101
return nil
91102
}
92103

93-
func persistDatabase(db *database.Database, backupPath string) error {
94-
if err := streamSnapshotToSocket(db); err == nil {
104+
func persistSnapshot(db *database.Database, connDB *connections.Connections, backupPath string) error {
105+
if err := streamSnapshotToSocket(db, connDB); err == nil {
95106
zap.L().Debug("Persisted backup via socket stream")
96107
return nil
97108
}
@@ -100,7 +111,7 @@ func persistDatabase(db *database.Database, backupPath string) error {
100111
}
101112

102113
zap.L().Debug("Persisting backup via file", zap.String("path", backupPath))
103-
return database.WriteSnapshotFile(db, backupPath)
114+
return writeCombinedSnapshotFile(db, connDB, backupPath)
104115
}
105116

106117
func streamSnapshotFromDaemon(processID int, writer io.Writer) error {
@@ -149,7 +160,7 @@ func streamSnapshotFromDaemon(processID int, writer io.Writer) error {
149160
return nil
150161
}
151162

152-
func streamSnapshotToSocket(db *database.Database) error {
163+
func streamSnapshotToSocket(db *database.Database, connDB *connections.Connections) error {
153164
socketPath := backupSocketPath(os.Getpid())
154165
zap.L().Debug("Dialing backup socket", zap.String("path", socketPath))
155166
dialer := net.Dialer{Timeout: backupDialTimeout}
@@ -166,7 +177,7 @@ func streamSnapshotToSocket(db *database.Database) error {
166177
}
167178

168179
zap.L().Debug("Writing snapshot to socket")
169-
return db.Snapshot(conn)
180+
return writeCombinedSnapshot(conn, db, connDB)
170181
}
171182

172183
func exportBackupFile(conf *config.Configuration, writer io.Writer) error {
@@ -197,6 +208,43 @@ func exportBackupFile(conf *config.Configuration, writer io.Writer) error {
197208
return nil
198209
}
199210

211+
func writeCombinedSnapshotFile(db *database.Database, connDB *connections.Connections, backupPath string) (err error) {
212+
if backupPath == "" {
213+
return errors.New("backup path is empty")
214+
}
215+
216+
tmpDir := filepath.Dir(backupPath)
217+
tmpFile, err := os.CreateTemp(tmpDir, "trakx-db-*")
218+
if err != nil {
219+
return errors.Wrap(err, "failed to create temp backup file")
220+
}
221+
tmpPath := tmpFile.Name()
222+
defer func() {
223+
if err != nil {
224+
tmpFile.Close()
225+
os.Remove(tmpPath)
226+
}
227+
}()
228+
229+
if err = writeCombinedSnapshot(tmpFile, db, connDB); err != nil {
230+
return errors.Wrap(err, "failed to write combined snapshot")
231+
}
232+
if err = tmpFile.Sync(); err != nil {
233+
return errors.Wrap(err, "failed to sync backup file")
234+
}
235+
if err = tmpFile.Close(); err != nil {
236+
return errors.Wrap(err, "failed to close backup file")
237+
}
238+
if err = os.Rename(tmpPath, backupPath); err != nil {
239+
return errors.Wrap(err, "failed to replace backup file")
240+
}
241+
if err = os.Chmod(backupPath, backupFilePermissions); err != nil {
242+
return errors.Wrap(err, "failed to set backup file permissions")
243+
}
244+
245+
return nil
246+
}
247+
200248
func readProcessID(path string) (int, error) {
201249
zap.L().Debug("Reading process id file", zap.String("path", path))
202250
contents, err := os.ReadFile(path)

daemon/backup_test.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ import (
1515
"github.com/crimist/trakx/stats"
1616
"github.com/crimist/trakx/storage"
1717
"github.com/crimist/trakx/storage/database"
18+
"github.com/crimist/trakx/tracker/udp/connections"
1819
)
1920

2021
func TestStreamSnapshotToSocket(t *testing.T) {
2122
db, err := database.NewDatabase(database.Config{
2223
InitalSize: 1,
23-
PersistanceAddress: "",
2424
Collector: stats.NewCollectors(false, false, 0),
2525
})
2626
if err != nil {
@@ -33,6 +33,10 @@ func TestStreamSnapshotToSocket(t *testing.T) {
3333
copy(peerID[:], bytes.Repeat([]byte{2}, len(peerID)))
3434
db.PeerAdd(hash, peerID, netip.MustParseAddr("127.0.0.1"), 1234, true)
3535

36+
connDB := connections.NewConnections(1, time.Minute, 0)
37+
udpAddr := netip.MustParseAddrPort("1.1.1.1:1234")
38+
connID := connDB.Create(udpAddr)
39+
3640
socketPath := backupSocketPath(os.Getpid())
3741
if err := removeSocketPath(socketPath); err != nil {
3842
t.Fatal(err)
@@ -67,28 +71,38 @@ func TestStreamSnapshotToSocket(t *testing.T) {
6771
dataCh <- data
6872
}()
6973

70-
if err := streamSnapshotToSocket(db); err != nil {
74+
if err := streamSnapshotToSocket(db, connDB); err != nil {
7175
t.Fatalf("streamSnapshotToSocket failed: %v", err)
7276
}
7377

7478
select {
7579
case err := <-errCh:
7680
t.Fatalf("listener error: %v", err)
7781
case data := <-dataCh:
82+
connSnapshot, dbReader, err := splitCombinedSnapshot(bytes.NewReader(data))
83+
if err != nil {
84+
t.Fatalf("splitCombinedSnapshot failed: %v", err)
85+
}
7886
restored, err := database.NewDatabase(database.Config{
7987
InitalSize: 1,
80-
PersistanceAddress: "",
8188
Collector: stats.NewCollectors(false, false, 0),
8289
})
8390
if err != nil {
8491
t.Fatal("Failed to create database")
8592
}
86-
if err := restored.Restore(bytes.NewReader(data)); err != nil {
93+
if err := restored.Restore(dbReader); err != nil {
8794
t.Fatalf("Restore failed: %v", err)
8895
}
8996
if restored.Torrents() != 1 {
9097
t.Fatalf("torrents = %d, want 1", restored.Torrents())
9198
}
99+
restoredConnections := connections.NewConnections(1, time.Minute, 0)
100+
if err := restoredConnections.Unmarshal(connSnapshot); err != nil {
101+
t.Fatalf("Failed to restore connections: %v", err)
102+
}
103+
if !restoredConnections.Validate(udpAddr, connID) {
104+
t.Fatalf("restored connections missing expected entry")
105+
}
92106
case <-time.After(backupAcceptTimeout + time.Second):
93107
t.Fatal("timed out waiting for snapshot")
94108
}
@@ -128,10 +142,10 @@ func TestImportBackupRejectsInvalidData(t *testing.T) {
128142
conf.DB.Backup.Path = backupPath
129143

130144
invalid := bytes.NewBuffer(nil)
131-
if _, err := invalid.WriteString("TRAKXDB"); err != nil {
145+
if _, err := invalid.WriteString(combinedSnapshotMagic); err != nil {
132146
t.Fatal(err)
133147
}
134-
if err := binary.Write(invalid, binary.LittleEndian, uint16(2)); err != nil {
148+
if err := binary.Write(invalid, binary.LittleEndian, uint16(combinedSnapshotVersion+1)); err != nil {
135149
t.Fatal(err)
136150
}
137151

daemon/combined_snapshot.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package daemon
2+
3+
import (
4+
"bufio"
5+
"encoding/binary"
6+
"errors"
7+
"io"
8+
9+
"github.com/crimist/trakx/storage/database"
10+
"github.com/crimist/trakx/tracker/udp/connections"
11+
)
12+
13+
const (
14+
combinedSnapshotMagic = "TRAKXSNAP"
15+
combinedSnapshotVersion = uint16(1)
16+
combinedSnapshotMagicSize = len(combinedSnapshotMagic)
17+
)
18+
19+
var (
20+
errCombinedSnapshotMagic = errors.New("invalid combined snapshot magic")
21+
errCombinedSnapshotVersion = errors.New("unsupported combined snapshot version")
22+
errCombinedSnapshotSize = errors.New("connections snapshot too large")
23+
)
24+
25+
type snapshotReader struct {
26+
*bufio.Reader
27+
closer io.Closer
28+
}
29+
30+
func (reader *snapshotReader) Close() error {
31+
if reader.closer != nil {
32+
return reader.closer.Close()
33+
}
34+
return nil
35+
}
36+
37+
func splitCombinedSnapshot(reader io.Reader) ([]byte, io.Reader, error) {
38+
bufReader := bufio.NewReader(reader)
39+
40+
header := make([]byte, combinedSnapshotMagicSize)
41+
if _, err := io.ReadFull(bufReader, header); err != nil {
42+
return nil, nil, err
43+
}
44+
if string(header) != combinedSnapshotMagic {
45+
return nil, nil, errCombinedSnapshotMagic
46+
}
47+
48+
var version uint16
49+
if err := binary.Read(bufReader, binary.LittleEndian, &version); err != nil {
50+
return nil, nil, err
51+
}
52+
if version != combinedSnapshotVersion {
53+
return nil, nil, errCombinedSnapshotVersion
54+
}
55+
56+
var connLen uint64
57+
if err := binary.Read(bufReader, binary.LittleEndian, &connLen); err != nil {
58+
return nil, nil, err
59+
}
60+
61+
maxInt := int(^uint(0) >> 1)
62+
if connLen > uint64(maxInt) {
63+
return nil, nil, errCombinedSnapshotSize
64+
}
65+
66+
var connBytes []byte
67+
if connLen > 0 {
68+
connBytes = make([]byte, int(connLen))
69+
if _, err := io.ReadFull(bufReader, connBytes); err != nil {
70+
return nil, nil, err
71+
}
72+
}
73+
74+
wrapped := &snapshotReader{Reader: bufReader}
75+
if closer, ok := reader.(io.Closer); ok {
76+
wrapped.closer = closer
77+
}
78+
79+
return connBytes, wrapped, nil
80+
}
81+
82+
func writeCombinedSnapshot(writer io.Writer, db *database.Database, connDB *connections.Connections) error {
83+
bufWriter := bufio.NewWriter(writer)
84+
85+
if _, err := bufWriter.WriteString(combinedSnapshotMagic); err != nil {
86+
return err
87+
}
88+
if err := binary.Write(bufWriter, binary.LittleEndian, combinedSnapshotVersion); err != nil {
89+
return err
90+
}
91+
92+
var connBytes []byte
93+
if connDB != nil {
94+
var err error
95+
connBytes, err = connDB.Marshal()
96+
if err != nil {
97+
return err
98+
}
99+
}
100+
if err := binary.Write(bufWriter, binary.LittleEndian, uint64(len(connBytes))); err != nil {
101+
return err
102+
}
103+
if len(connBytes) > 0 {
104+
if _, err := bufWriter.Write(connBytes); err != nil {
105+
return err
106+
}
107+
}
108+
if err := bufWriter.Flush(); err != nil {
109+
return err
110+
}
111+
112+
return db.Snapshot(writer)
113+
}

0 commit comments

Comments
 (0)