@@ -743,6 +743,168 @@ func TestDecryptReaderAtTruncatedChunk(t *testing.T) {
743743 }
744744}
745745
746+ func TestDecryptReaderAtConcurrent (t * testing.T ) {
747+ key := make ([]byte , chacha20poly1305 .KeySize )
748+ if _ , err := rand .Read (key ); err != nil {
749+ t .Fatal (err )
750+ }
751+
752+ // Create plaintext spanning 3 chunks: 2 full + partial
753+ plaintextSize := 2 * cs + 500
754+ plaintext := make ([]byte , plaintextSize )
755+ if _ , err := rand .Read (plaintext ); err != nil {
756+ t .Fatal (err )
757+ }
758+
759+ // Encrypt
760+ buf := & bytes.Buffer {}
761+ w , err := stream .NewEncryptWriter (key , buf )
762+ if err != nil {
763+ t .Fatal (err )
764+ }
765+ if _ , err := w .Write (plaintext ); err != nil {
766+ t .Fatal (err )
767+ }
768+ if err := w .Close (); err != nil {
769+ t .Fatal (err )
770+ }
771+ ciphertext := buf .Bytes ()
772+
773+ ra , err := stream .NewDecryptReaderAt (key , bytes .NewReader (ciphertext ), int64 (len (ciphertext )))
774+ if err != nil {
775+ t .Fatal (err )
776+ }
777+
778+ t .Run ("same chunk" , func (t * testing.T ) {
779+ t .Parallel ()
780+ const goroutines = 10
781+ const iterations = 100
782+ errc := make (chan error , goroutines )
783+
784+ for g := range goroutines {
785+ go func (id int ) {
786+ for i := range iterations {
787+ off := int64 ((id * iterations + i ) % 500 )
788+ p := make ([]byte , 100 )
789+ n , err := ra .ReadAt (p , off )
790+ if err != nil {
791+ errc <- fmt .Errorf ("goroutine %d iter %d: %v" , id , i , err )
792+ return
793+ }
794+ if n != 100 {
795+ errc <- fmt .Errorf ("goroutine %d iter %d: n=%d, want 100" , id , i , n )
796+ return
797+ }
798+ if ! bytes .Equal (p , plaintext [off :off + 100 ]) {
799+ errc <- fmt .Errorf ("goroutine %d iter %d: data mismatch" , id , i )
800+ return
801+ }
802+ }
803+ errc <- nil
804+ }(g )
805+ }
806+
807+ for range goroutines {
808+ if err := <- errc ; err != nil {
809+ t .Error (err )
810+ }
811+ }
812+ })
813+
814+ t .Run ("different chunks" , func (t * testing.T ) {
815+ t .Parallel ()
816+ const goroutines = 10
817+ const iterations = 100
818+ errc := make (chan error , goroutines )
819+
820+ for g := range goroutines {
821+ go func (id int ) {
822+ for i := range iterations {
823+ // Each goroutine reads from a different chunk based on id
824+ chunkIdx := id % 3
825+ off := int64 (chunkIdx * cs + (i % 400 ))
826+ size := 100
827+ if off + int64 (size ) > int64 (plaintextSize ) {
828+ size = plaintextSize - int (off )
829+ }
830+ p := make ([]byte , size )
831+ n , err := ra .ReadAt (p , off )
832+ if n == size && err == io .EOF {
833+ err = nil // EOF at end is acceptable
834+ }
835+ if err != nil {
836+ errc <- fmt .Errorf ("goroutine %d iter %d: off=%d: %v" , id , i , off , err )
837+ return
838+ }
839+ if n != size {
840+ errc <- fmt .Errorf ("goroutine %d iter %d: n=%d, want %d" , id , i , n , size )
841+ return
842+ }
843+ if ! bytes .Equal (p [:n ], plaintext [off :off + int64 (n )]) {
844+ errc <- fmt .Errorf ("goroutine %d iter %d: data mismatch" , id , i )
845+ return
846+ }
847+ }
848+ errc <- nil
849+ }(g )
850+ }
851+
852+ for range goroutines {
853+ if err := <- errc ; err != nil {
854+ t .Error (err )
855+ }
856+ }
857+ })
858+
859+ t .Run ("across chunks" , func (t * testing.T ) {
860+ t .Parallel ()
861+ const goroutines = 10
862+ const iterations = 100
863+ errc := make (chan error , goroutines )
864+
865+ for g := range goroutines {
866+ go func (id int ) {
867+ for i := range iterations {
868+ // Read across chunk boundaries
869+ boundary := (id % 2 + 1 ) * cs // either cs or 2*cs
870+ off := int64 (boundary - 50 + (i % 30 ))
871+ size := 100
872+ if off + int64 (size ) > int64 (plaintextSize ) {
873+ size = plaintextSize - int (off )
874+ }
875+ if size <= 0 {
876+ continue
877+ }
878+ p := make ([]byte , size )
879+ n , err := ra .ReadAt (p , off )
880+ if n == size && err == io .EOF {
881+ err = nil
882+ }
883+ if err != nil {
884+ errc <- fmt .Errorf ("goroutine %d iter %d: off=%d size=%d: %v" , id , i , off , size , err )
885+ return
886+ }
887+ if n != size {
888+ errc <- fmt .Errorf ("goroutine %d iter %d: n=%d, want %d" , id , i , n , size )
889+ return
890+ }
891+ if ! bytes .Equal (p [:n ], plaintext [off :off + int64 (n )]) {
892+ errc <- fmt .Errorf ("goroutine %d iter %d: data mismatch" , id , i )
893+ return
894+ }
895+ }
896+ errc <- nil
897+ }(g )
898+ }
899+
900+ for range goroutines {
901+ if err := <- errc ; err != nil {
902+ t .Error (err )
903+ }
904+ }
905+ })
906+ }
907+
746908func TestDecryptReaderAtCorrupted (t * testing.T ) {
747909 key := make ([]byte , chacha20poly1305 .KeySize )
748910 if _ , err := rand .Read (key ); err != nil {
0 commit comments