Skip to content

Commit f881db4

Browse files
committed
internal/stream: fix DecryptReaderAt concurrency
1 parent da21917 commit f881db4

File tree

2 files changed

+168
-1
lines changed

2 files changed

+168
-1
lines changed

internal/stream/stream.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
398398
if len(p) == 0 {
399399
return 0, nil
400400
}
401+
var cache *cachedChunk
401402
chunk := make([]byte, encChunkSize)
402403
for len(p) > 0 && off < r.size {
403404
chunkIndex := off / ChunkSize
@@ -409,6 +410,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
409410
var plaintext []byte
410411
if cached != nil && cached.off == chunkOff {
411412
plaintext = cached.data
413+
cache = nil
412414
} else {
413415
nn, err := r.src.ReadAt(chunk[:chunkSize], chunkOff)
414416
if err == io.EOF {
@@ -429,7 +431,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
429431
if err != nil {
430432
return n, fmt.Errorf("failed to decrypt and authenticate chunk at offset %d: %w", chunkOff, err)
431433
}
432-
r.cache.Store(&cachedChunk{off: chunkOff, data: plaintext})
434+
cache = &cachedChunk{off: chunkOff, data: plaintext}
433435
}
434436

435437
plainChunkOff := int(off - chunkIndex*ChunkSize)
@@ -439,6 +441,9 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
439441
off += int64(copySize)
440442
n += copySize
441443
}
444+
if cache != nil {
445+
r.cache.Store(cache)
446+
}
442447
if off == r.size {
443448
return n, io.EOF
444449
}

internal/stream/stream_test.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,168 @@ func TestDecryptReaderAtTruncatedChunk(t *testing.T) {
743743
}
744744
}
745745

746+
func TestDecryptReaderAtConcurrent(t *testing.T) {
747+
key := make([]byte, chacha20poly1305.KeySize)
748+
if _, err := rand.Read(key); err != nil {
749+
t.Fatal(err)
750+
}
751+
752+
// Create plaintext spanning 3 chunks: 2 full + partial
753+
plaintextSize := 2*cs + 500
754+
plaintext := make([]byte, plaintextSize)
755+
if _, err := rand.Read(plaintext); err != nil {
756+
t.Fatal(err)
757+
}
758+
759+
// Encrypt
760+
buf := &bytes.Buffer{}
761+
w, err := stream.NewEncryptWriter(key, buf)
762+
if err != nil {
763+
t.Fatal(err)
764+
}
765+
if _, err := w.Write(plaintext); err != nil {
766+
t.Fatal(err)
767+
}
768+
if err := w.Close(); err != nil {
769+
t.Fatal(err)
770+
}
771+
ciphertext := buf.Bytes()
772+
773+
ra, err := stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)))
774+
if err != nil {
775+
t.Fatal(err)
776+
}
777+
778+
t.Run("same chunk", func(t *testing.T) {
779+
t.Parallel()
780+
const goroutines = 10
781+
const iterations = 100
782+
errc := make(chan error, goroutines)
783+
784+
for g := range goroutines {
785+
go func(id int) {
786+
for i := range iterations {
787+
off := int64((id*iterations + i) % 500)
788+
p := make([]byte, 100)
789+
n, err := ra.ReadAt(p, off)
790+
if err != nil {
791+
errc <- fmt.Errorf("goroutine %d iter %d: %v", id, i, err)
792+
return
793+
}
794+
if n != 100 {
795+
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want 100", id, i, n)
796+
return
797+
}
798+
if !bytes.Equal(p, plaintext[off:off+100]) {
799+
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
800+
return
801+
}
802+
}
803+
errc <- nil
804+
}(g)
805+
}
806+
807+
for range goroutines {
808+
if err := <-errc; err != nil {
809+
t.Error(err)
810+
}
811+
}
812+
})
813+
814+
t.Run("different chunks", func(t *testing.T) {
815+
t.Parallel()
816+
const goroutines = 10
817+
const iterations = 100
818+
errc := make(chan error, goroutines)
819+
820+
for g := range goroutines {
821+
go func(id int) {
822+
for i := range iterations {
823+
// Each goroutine reads from a different chunk based on id
824+
chunkIdx := id % 3
825+
off := int64(chunkIdx*cs + (i % 400))
826+
size := 100
827+
if off+int64(size) > int64(plaintextSize) {
828+
size = plaintextSize - int(off)
829+
}
830+
p := make([]byte, size)
831+
n, err := ra.ReadAt(p, off)
832+
if n == size && err == io.EOF {
833+
err = nil // EOF at end is acceptable
834+
}
835+
if err != nil {
836+
errc <- fmt.Errorf("goroutine %d iter %d: off=%d: %v", id, i, off, err)
837+
return
838+
}
839+
if n != size {
840+
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want %d", id, i, n, size)
841+
return
842+
}
843+
if !bytes.Equal(p[:n], plaintext[off:off+int64(n)]) {
844+
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
845+
return
846+
}
847+
}
848+
errc <- nil
849+
}(g)
850+
}
851+
852+
for range goroutines {
853+
if err := <-errc; err != nil {
854+
t.Error(err)
855+
}
856+
}
857+
})
858+
859+
t.Run("across chunks", func(t *testing.T) {
860+
t.Parallel()
861+
const goroutines = 10
862+
const iterations = 100
863+
errc := make(chan error, goroutines)
864+
865+
for g := range goroutines {
866+
go func(id int) {
867+
for i := range iterations {
868+
// Read across chunk boundaries
869+
boundary := (id%2 + 1) * cs // either cs or 2*cs
870+
off := int64(boundary - 50 + (i % 30))
871+
size := 100
872+
if off+int64(size) > int64(plaintextSize) {
873+
size = plaintextSize - int(off)
874+
}
875+
if size <= 0 {
876+
continue
877+
}
878+
p := make([]byte, size)
879+
n, err := ra.ReadAt(p, off)
880+
if n == size && err == io.EOF {
881+
err = nil
882+
}
883+
if err != nil {
884+
errc <- fmt.Errorf("goroutine %d iter %d: off=%d size=%d: %v", id, i, off, size, err)
885+
return
886+
}
887+
if n != size {
888+
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want %d", id, i, n, size)
889+
return
890+
}
891+
if !bytes.Equal(p[:n], plaintext[off:off+int64(n)]) {
892+
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
893+
return
894+
}
895+
}
896+
errc <- nil
897+
}(g)
898+
}
899+
900+
for range goroutines {
901+
if err := <-errc; err != nil {
902+
t.Error(err)
903+
}
904+
}
905+
})
906+
}
907+
746908
func TestDecryptReaderAtCorrupted(t *testing.T) {
747909
key := make([]byte, chacha20poly1305.KeySize)
748910
if _, err := rand.Read(key); err != nil {

0 commit comments

Comments
 (0)