Skip to content

Commit bcdd856

Browse files
authored
Merge branch 'main' into alert-autofix-6
2 parents f8808d2 + 323d8e8 commit bcdd856

File tree

6 files changed

+140
-10
lines changed

6 files changed

+140
-10
lines changed

.github/workflows/golangci-lint.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ concurrency:
1111
jobs:
1212
golangci-lint:
1313
name: golangci-lint
14+
permissions:
15+
contents: read
16+
pull-requests: write
1417
runs-on: ubuntu-latest
1518
steps:
1619
- name: Check out code into the Go module directory

store/memory_store.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package store
33
import (
44
"bytes"
55
"context"
6+
"encoding/binary"
67
"encoding/gob"
8+
"hash/crc32"
79
"io"
810
"log/slog"
911
"os"
@@ -209,22 +211,38 @@ func (s *memoryStore) Snapshot() (io.ReadWriter, error) {
209211
s.mtx.RUnlock()
210212

211213
buf := &bytes.Buffer{}
212-
err := gob.NewEncoder(buf).Encode(cl)
213-
if err != nil {
214+
if err := gob.NewEncoder(buf).Encode(cl); err != nil {
215+
return nil, errors.WithStack(err)
216+
}
217+
218+
sum := crc32.ChecksumIEEE(buf.Bytes())
219+
if err := binary.Write(buf, binary.LittleEndian, sum); err != nil {
214220
return nil, errors.WithStack(err)
215221
}
216222

217223
return buf, nil
218224
}
219-
func (s *memoryStore) Restore(buf io.Reader) error {
225+
func (s *memoryStore) Restore(r io.Reader) error {
220226
s.mtx.Lock()
221227
defer s.mtx.Unlock()
222228

223-
s.m = make(map[uint64][]byte)
224-
err := gob.NewDecoder(buf).Decode(&s.m)
229+
data, err := io.ReadAll(r)
225230
if err != nil {
226231
return errors.WithStack(err)
227232
}
233+
if len(data) < 4 {
234+
return errors.WithStack(ErrInvalidChecksum)
235+
}
236+
payload := data[:len(data)-4]
237+
expected := binary.LittleEndian.Uint32(data[len(data)-4:])
238+
if crc32.ChecksumIEEE(payload) != expected {
239+
return errors.WithStack(ErrInvalidChecksum)
240+
}
241+
242+
s.m = make(map[uint64][]byte)
243+
if err := gob.NewDecoder(bytes.NewReader(payload)).Decode(&s.m); err != nil {
244+
return errors.WithStack(err)
245+
}
228246

229247
return nil
230248
}

store/memory_store_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package store
22

33
import (
4+
"bytes"
45
"context"
56
"strconv"
67
"sync"
@@ -204,6 +205,46 @@ func TestMemoryStore_TTL(t *testing.T) {
204205
wg.Wait()
205206
}
206207

208+
func TestMemoryStore_SnapshotChecksum(t *testing.T) {
209+
t.Parallel()
210+
211+
t.Run("success", func(t *testing.T) {
212+
t.Parallel()
213+
ctx := context.Background()
214+
st := NewMemoryStore()
215+
assert.NoError(t, st.Put(ctx, []byte("foo"), []byte("bar")))
216+
217+
buf, err := st.Snapshot()
218+
assert.NoError(t, err)
219+
220+
st2 := NewMemoryStore()
221+
err = st2.Restore(bytes.NewReader(buf.(*bytes.Buffer).Bytes()))
222+
assert.NoError(t, err)
223+
224+
v, err := st2.Get(ctx, []byte("foo"))
225+
assert.NoError(t, err)
226+
assert.Equal(t, []byte("bar"), v)
227+
})
228+
229+
t.Run("corrupted", func(t *testing.T) {
230+
t.Parallel()
231+
ctx := context.Background()
232+
st := NewMemoryStore()
233+
assert.NoError(t, st.Put(ctx, []byte("foo"), []byte("bar")))
234+
235+
buf, err := st.Snapshot()
236+
assert.NoError(t, err)
237+
data := buf.(*bytes.Buffer).Bytes()
238+
corrupted := make([]byte, len(data))
239+
copy(corrupted, data)
240+
corrupted[0] ^= 0xff
241+
242+
st2 := NewMemoryStore()
243+
err = st2.Restore(bytes.NewReader(corrupted))
244+
assert.ErrorIs(t, err, ErrInvalidChecksum)
245+
})
246+
}
247+
207248
func TestMemoryStore_TTL_Txn(t *testing.T) {
208249
ctx := context.Background()
209250
t.Parallel()

store/rb_memory_store.go

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package store
33
import (
44
"bytes"
55
"context"
6+
"encoding/binary"
67
"encoding/gob"
8+
"hash/crc32"
79
"io"
810
"log/slog"
911
"os"
@@ -229,22 +231,46 @@ func (s *rbMemoryStore) Snapshot() (io.ReadWriter, error) {
229231
s.mtx.RUnlock()
230232

231233
buf := &bytes.Buffer{}
232-
err := gob.NewEncoder(buf).Encode(cl)
233-
if err != nil {
234+
if err := gob.NewEncoder(buf).Encode(cl); err != nil {
235+
return nil, errors.WithStack(err)
236+
}
237+
238+
sum := crc32.ChecksumIEEE(buf.Bytes())
239+
if err := binary.Write(buf, binary.LittleEndian, sum); err != nil {
234240
return nil, errors.WithStack(err)
235241
}
236242

237243
return buf, nil
238244
}
239-
func (s *rbMemoryStore) Restore(buf io.Reader) error {
245+
func (s *rbMemoryStore) Restore(r io.Reader) error {
240246
s.mtx.Lock()
241247
defer s.mtx.Unlock()
242248

243-
s.tree.Clear()
244-
err := gob.NewDecoder(buf).Decode(&s.tree)
249+
data, err := io.ReadAll(r)
245250
if err != nil {
246251
return errors.WithStack(err)
247252
}
253+
if len(data) < 4 {
254+
return errors.WithStack(ErrInvalidChecksum)
255+
}
256+
payload := data[:len(data)-4]
257+
expected := binary.LittleEndian.Uint32(data[len(data)-4:])
258+
if crc32.ChecksumIEEE(payload) != expected {
259+
return errors.WithStack(ErrInvalidChecksum)
260+
}
261+
262+
var cl map[*[]byte][]byte
263+
if err := gob.NewDecoder(bytes.NewReader(payload)).Decode(&cl); err != nil {
264+
return errors.WithStack(err)
265+
}
266+
267+
s.tree.Clear()
268+
for k, v := range cl {
269+
if k == nil {
270+
continue
271+
}
272+
s.tree.Put(*k, v)
273+
}
248274

249275
return nil
250276
}

store/rb_memory_store_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package store
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/binary"
67
"strconv"
@@ -205,6 +206,46 @@ func TestRbMemoryStore_Txn(t *testing.T) {
205206
})
206207
}
207208

209+
func TestRbMemoryStore_SnapshotChecksum(t *testing.T) {
210+
t.Parallel()
211+
212+
t.Run("success", func(t *testing.T) {
213+
t.Parallel()
214+
ctx := context.Background()
215+
st := NewRbMemoryStore()
216+
assert.NoError(t, st.Put(ctx, []byte("foo"), []byte("bar")))
217+
218+
buf, err := st.Snapshot()
219+
assert.NoError(t, err)
220+
221+
st2 := NewRbMemoryStore()
222+
err = st2.Restore(bytes.NewReader(buf.(*bytes.Buffer).Bytes()))
223+
assert.NoError(t, err)
224+
225+
v, err := st2.Get(ctx, []byte("foo"))
226+
assert.NoError(t, err)
227+
assert.Equal(t, []byte("bar"), v)
228+
})
229+
230+
t.Run("corrupted", func(t *testing.T) {
231+
t.Parallel()
232+
ctx := context.Background()
233+
st := NewRbMemoryStore()
234+
assert.NoError(t, st.Put(ctx, []byte("foo"), []byte("bar")))
235+
236+
buf, err := st.Snapshot()
237+
assert.NoError(t, err)
238+
data := buf.(*bytes.Buffer).Bytes()
239+
corrupted := make([]byte, len(data))
240+
copy(corrupted, data)
241+
corrupted[0] ^= 0xff
242+
243+
st2 := NewRbMemoryStore()
244+
err = st2.Restore(bytes.NewReader(corrupted))
245+
assert.ErrorIs(t, err, ErrInvalidChecksum)
246+
})
247+
}
248+
208249
func TestRbMemoryStore_TTL(t *testing.T) {
209250
ctx := context.Background()
210251
t.Parallel()

store/store.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
var ErrKeyNotFound = errors.New("not found")
1111
var ErrUnknownOp = errors.New("unknown op")
1212
var ErrNotSupported = errors.New("not supported")
13+
var ErrInvalidChecksum = errors.New("invalid checksum")
1314

1415
type KVPair struct {
1516
Key []byte

0 commit comments

Comments
 (0)