Skip to content

Commit 2ff5d34

Browse files
committed
age: add DecryptReaderAt
1 parent abe371e commit 2ff5d34

File tree

6 files changed

+1101
-119
lines changed

6 files changed

+1101
-119
lines changed

age.go

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,13 +252,12 @@ func (e *NoIdentityMatchError) Unwrap() []error {
252252
}
253253

254254
// Decrypt decrypts a file encrypted to one or more identities.
255-
//
256-
// It returns a Reader reading the decrypted plaintext of the age file read
257-
// from src. All identities will be tried until one successfully decrypts the file.
255+
// All identities will be tried until one successfully decrypts the file.
258256
// Native, non-interactive identities are tried before any other identities.
259257
//
260-
// If no identity matches the encrypted file, the returned error will be of type
261-
// [NoIdentityMatchError].
258+
// Decrypt returns a Reader reading the decrypted plaintext of the age file read
259+
// from src. If no identity matches the encrypted file, the returned error will
260+
// be of type [NoIdentityMatchError].
262261
func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
263262
hdr, payload, err := format.Parse(src)
264263
if err != nil {
@@ -278,6 +277,58 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
278277
return stream.NewDecryptReader(streamKey(fileKey, nonce), payload)
279278
}
280279

280+
// DecryptReaderAt decrypts a file encrypted to one or more identities.
281+
// All identities will be tried until one successfully decrypts the file.
282+
// Native, non-interactive identities are tried before any other identities.
283+
//
284+
// DecryptReaderAt takes an underlying [io.ReaderAt] and its total encrypted
285+
// size, and returns a ReaderAt of the decrypted plaintext and the plaintext
286+
// size. These can be used for example to instantiate an [io.SectionReader],
287+
// which implements [io.Reader] and [io.Seeker]. Note that ReaderAt by
288+
// definition disregards the seek position of src.
289+
//
290+
// The ReadAt method of the returned ReaderAt can be called concurrently.
291+
// The ReaderAt will internally cache the most recently decrypted chunk.
292+
// DecryptReaderAt reads and decrypts the final chunk before returning,
293+
// to authenticate the plaintext size.
294+
//
295+
// If no identity matches the encrypted file, the returned error will be of
296+
// type [NoIdentityMatchError].
297+
func DecryptReaderAt(src io.ReaderAt, encryptedSize int64, identities ...Identity) (io.ReaderAt, int64, error) {
298+
srcReader := io.NewSectionReader(src, 0, encryptedSize)
299+
hdr, payload, err := format.Parse(srcReader)
300+
if err != nil {
301+
return nil, 0, fmt.Errorf("failed to read header: %w", err)
302+
}
303+
buf := &bytes.Buffer{}
304+
if err := hdr.Marshal(buf); err != nil {
305+
return nil, 0, fmt.Errorf("failed to serialize header: %w", err)
306+
}
307+
308+
fileKey, err := decryptHdr(hdr, identities...)
309+
if err != nil {
310+
return nil, 0, err
311+
}
312+
313+
nonce := make([]byte, streamNonceSize)
314+
if _, err := io.ReadFull(payload, nonce); err != nil {
315+
return nil, 0, fmt.Errorf("failed to read nonce: %w", err)
316+
}
317+
318+
payloadOffset := int64(buf.Len()) + int64(len(nonce))
319+
payloadSize := encryptedSize - payloadOffset
320+
plaintextSize, err := stream.PlaintextSize(payloadSize)
321+
if err != nil {
322+
return nil, 0, err
323+
}
324+
payloadReaderAt := io.NewSectionReader(src, payloadOffset, payloadSize)
325+
r, err := stream.NewDecryptReaderAt(streamKey(fileKey, nonce), payloadReaderAt, payloadSize)
326+
if err != nil {
327+
return nil, 0, err
328+
}
329+
return r, plaintextSize, nil
330+
}
331+
281332
func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) {
282333
if len(identities) == 0 {
283334
return nil, errors.New("no identities specified")

internal/inspect/inspect.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"filippo.io/age/armor"
1111
"filippo.io/age/internal/format"
1212
"filippo.io/age/internal/stream"
13-
"golang.org/x/crypto/chacha20poly1305"
1413
)
1514

1615
type Metadata struct {
@@ -88,9 +87,9 @@ func Inspect(r io.Reader, fileSize int64) (*Metadata, error) {
8887
}
8988
data.Sizes.Armor = tr.count - fileSize
9089
}
91-
data.Sizes.Overhead = streamOverhead(fileSize - data.Sizes.Header)
92-
if data.Sizes.Overhead > fileSize-data.Sizes.Header {
93-
return nil, fmt.Errorf("payload too small to be a valid age file")
90+
data.Sizes.Overhead, err = streamOverhead(fileSize - data.Sizes.Header)
91+
if err != nil {
92+
return nil, fmt.Errorf("failed to compute stream overhead: %w", err)
9493
}
9594
data.Sizes.MinPayload = fileSize - data.Sizes.Header - data.Sizes.Overhead
9695
data.Sizes.MaxPayload = data.Sizes.MinPayload
@@ -114,13 +113,15 @@ func (tr *trackReader) Read(p []byte) (int, error) {
114113
return n, err
115114
}
116115

117-
func streamOverhead(payloadSize int64) int64 {
116+
func streamOverhead(payloadSize int64) (int64, error) {
118117
const streamNonceSize = 16
119-
const encChunkSize = stream.ChunkSize + chacha20poly1305.Overhead
120-
payloadSize -= streamNonceSize
121-
if payloadSize <= 0 {
122-
return streamNonceSize
118+
if payloadSize < streamNonceSize {
119+
return 0, fmt.Errorf("encrypted size too small: %d", payloadSize)
120+
}
121+
encryptedSize := payloadSize - streamNonceSize
122+
plaintextSize, err := stream.PlaintextSize(encryptedSize)
123+
if err != nil {
124+
return 0, err
123125
}
124-
chunks := (payloadSize + encChunkSize - 1) / encChunkSize
125-
return streamNonceSize + chunks*chacha20poly1305.Overhead
126+
return payloadSize - plaintextSize, nil
126127
}

internal/inspect/inspect_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package inspect
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
7+
"filippo.io/age/internal/stream"
8+
)
9+
10+
func TestStreamOverhead(t *testing.T) {
11+
tests := []struct {
12+
payloadSize int64
13+
want int64
14+
wantErr bool
15+
}{
16+
{payloadSize: 0, wantErr: true},
17+
{payloadSize: 15, wantErr: true},
18+
{payloadSize: 16, wantErr: true},
19+
{payloadSize: 16 + 15, wantErr: true},
20+
{payloadSize: 16 + 16, want: 16 + 16}, // empty plaintext
21+
{payloadSize: 16 + 1 + 16, want: 16 + 16},
22+
{payloadSize: 16 + stream.ChunkSize + 16, want: 16 + 16},
23+
{payloadSize: 16 + stream.ChunkSize + 16 + 1, wantErr: true},
24+
{payloadSize: 16 + stream.ChunkSize + 16 + 15, wantErr: true},
25+
{payloadSize: 16 + stream.ChunkSize + 16 + 16, wantErr: true}, // empty final chunk
26+
{payloadSize: 16 + stream.ChunkSize + 16 + 1 + 16, want: 16 + 16 + 16},
27+
}
28+
for _, tt := range tests {
29+
name := "payloadSize=" + fmt.Sprint(tt.payloadSize)
30+
t.Run(name, func(t *testing.T) {
31+
got, gotErr := streamOverhead(tt.payloadSize)
32+
if gotErr != nil {
33+
if !tt.wantErr {
34+
t.Errorf("streamOverhead() failed: %v", gotErr)
35+
}
36+
return
37+
}
38+
if tt.wantErr {
39+
t.Fatal("streamOverhead() succeeded unexpectedly")
40+
}
41+
if got != tt.want {
42+
t.Errorf("streamOverhead() = %v, want %v", got, tt.want)
43+
}
44+
})
45+
}
46+
}

internal/stream/stream.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,42 @@ package stream
88
import (
99
"bytes"
1010
"crypto/cipher"
11+
"encoding/binary"
1112
"errors"
1213
"fmt"
1314
"io"
15+
"sync/atomic"
1416

1517
"golang.org/x/crypto/chacha20poly1305"
1618
)
1719

1820
const ChunkSize = 64 * 1024
1921

22+
func EncryptedChunkCount(encryptedSize int64) (int64, error) {
23+
chunks := (encryptedSize + encChunkSize - 1) / encChunkSize
24+
25+
plaintextSize := encryptedSize - chunks*chacha20poly1305.Overhead
26+
expChunks := (plaintextSize + ChunkSize - 1) / ChunkSize
27+
// Empty plaintext, the only case that allows (and requires) an empty chunk.
28+
if plaintextSize == 0 {
29+
expChunks = 1
30+
}
31+
if expChunks != chunks {
32+
return 0, fmt.Errorf("invalid encrypted payload size: %d", encryptedSize)
33+
}
34+
35+
return chunks, nil
36+
}
37+
38+
func PlaintextSize(encryptedSize int64) (int64, error) {
39+
chunks, err := EncryptedChunkCount(encryptedSize)
40+
if err != nil {
41+
return 0, err
42+
}
43+
plaintextSize := encryptedSize - chunks*chacha20poly1305.Overhead
44+
return plaintextSize, nil
45+
}
46+
2047
type DecryptReader struct {
2148
a cipher.AEAD
2249
src io.Reader
@@ -135,6 +162,12 @@ func incNonce(nonce *[chacha20poly1305.NonceSize]byte) {
135162
panic("stream: chunk counter wrapped around")
136163
}
137164

165+
func nonceForChunk(chunkIndex int64) *[chacha20poly1305.NonceSize]byte {
166+
var nonce [chacha20poly1305.NonceSize]byte
167+
binary.BigEndian.PutUint64(nonce[3:11], uint64(chunkIndex))
168+
return &nonce
169+
}
170+
138171
func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) {
139172
nonce[len(nonce)-1] = lastChunkFlag
140173
}
@@ -312,3 +345,102 @@ func (r *EncryptReader) feedBuffer() error {
312345

313346
return nil
314347
}
348+
349+
type DecryptReaderAt struct {
350+
a cipher.AEAD
351+
src io.ReaderAt
352+
size int64
353+
chunks int64
354+
cache atomic.Pointer[cachedChunk]
355+
}
356+
357+
type cachedChunk struct {
358+
off int64
359+
data []byte
360+
}
361+
362+
func NewDecryptReaderAt(key []byte, src io.ReaderAt, size int64) (*DecryptReaderAt, error) {
363+
aead, err := chacha20poly1305.New(key)
364+
if err != nil {
365+
return nil, err
366+
}
367+
368+
// Check that size is valid by decrypting the final chunk.
369+
chunks, err := EncryptedChunkCount(size)
370+
if err != nil {
371+
return nil, err
372+
}
373+
finalChunkIndex := chunks - 1
374+
finalChunkOff := finalChunkIndex * encChunkSize
375+
finalChunkSize := size - finalChunkOff
376+
finalChunk := make([]byte, finalChunkSize)
377+
if _, err := src.ReadAt(finalChunk, finalChunkOff); err != nil {
378+
return nil, fmt.Errorf("failed to read final chunk: %w", err)
379+
}
380+
nonce := nonceForChunk(finalChunkIndex)
381+
setLastChunkFlag(nonce)
382+
plaintext, err := aead.Open(finalChunk[:0], nonce[:], finalChunk, nil)
383+
if err != nil {
384+
return nil, fmt.Errorf("failed to decrypt and authenticate final chunk: %w", err)
385+
}
386+
cache := &cachedChunk{off: finalChunkOff, data: plaintext}
387+
388+
plaintextSize := size - chunks*chacha20poly1305.Overhead
389+
r := &DecryptReaderAt{a: aead, src: src, size: plaintextSize, chunks: chunks}
390+
r.cache.Store(cache)
391+
return r, nil
392+
}
393+
394+
func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
395+
if off < 0 || off > r.size {
396+
return 0, fmt.Errorf("offset out of range [0:%d]: %d", r.size, off)
397+
}
398+
if len(p) == 0 {
399+
return 0, nil
400+
}
401+
chunk := make([]byte, encChunkSize)
402+
for len(p) > 0 && off < r.size {
403+
chunkIndex := off / ChunkSize
404+
chunkOff := chunkIndex * encChunkSize
405+
encSize := r.size + r.chunks*chacha20poly1305.Overhead
406+
chunkSize := min(encSize-chunkOff, encChunkSize)
407+
408+
cached := r.cache.Load()
409+
var plaintext []byte
410+
if cached != nil && cached.off == chunkOff {
411+
plaintext = cached.data
412+
} else {
413+
nn, err := r.src.ReadAt(chunk[:chunkSize], chunkOff)
414+
if err == io.EOF {
415+
if int64(nn) != chunkSize {
416+
err = io.ErrUnexpectedEOF
417+
} else {
418+
err = nil
419+
}
420+
}
421+
if err != nil {
422+
return n, fmt.Errorf("failed to read chunk at offset %d: %w", chunkOff, err)
423+
}
424+
nonce := nonceForChunk(chunkIndex)
425+
if chunkIndex == r.chunks-1 {
426+
setLastChunkFlag(nonce)
427+
}
428+
plaintext, err = r.a.Open(chunk[:0], nonce[:], chunk[:chunkSize], nil)
429+
if err != nil {
430+
return n, fmt.Errorf("failed to decrypt and authenticate chunk at offset %d: %w", chunkOff, err)
431+
}
432+
r.cache.Store(&cachedChunk{off: chunkOff, data: plaintext})
433+
}
434+
435+
plainChunkOff := int(off - chunkIndex*ChunkSize)
436+
copySize := min(len(plaintext)-plainChunkOff, len(p))
437+
copy(p, plaintext[plainChunkOff:plainChunkOff+copySize])
438+
p = p[copySize:]
439+
off += int64(copySize)
440+
n += copySize
441+
}
442+
if off == r.size {
443+
return n, io.EOF
444+
}
445+
return n, nil
446+
}

0 commit comments

Comments
 (0)