Skip to content

Commit a8de3de

Browse files
committed
age: add ExtractHeader, DecryptHeader, and NewInjectedFileKeyIdentity
1 parent ae74b61 commit a8de3de

File tree

2 files changed

+117
-10
lines changed

2 files changed

+117
-10
lines changed

age.go

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
package age
4747

4848
import (
49+
"bytes"
4950
"crypto/hmac"
5051
"crypto/rand"
5152
"errors"
@@ -207,22 +208,37 @@ func (*NoIdentityMatchError) Error() string {
207208
// If no identity matches the encrypted file, the returned error will be of type
208209
// [NoIdentityMatchError].
209210
func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
210-
if len(identities) == 0 {
211-
return nil, errors.New("no identities specified")
212-
}
213-
214211
hdr, payload, err := format.Parse(src)
215212
if err != nil {
216213
return nil, fmt.Errorf("failed to read header: %w", err)
217214
}
218215

216+
fileKey, err := decryptHdr(hdr, identities...)
217+
if err != nil {
218+
return nil, err
219+
}
220+
221+
nonce := make([]byte, streamNonceSize)
222+
if _, err := io.ReadFull(payload, nonce); err != nil {
223+
return nil, fmt.Errorf("failed to read nonce: %w", err)
224+
}
225+
226+
return stream.NewReader(streamKey(fileKey, nonce), payload)
227+
}
228+
229+
func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) {
230+
if len(identities) == 0 {
231+
return nil, errors.New("no identities specified")
232+
}
233+
219234
stanzas := make([]*Stanza, 0, len(hdr.Recipients))
220235
for _, s := range hdr.Recipients {
221236
stanzas = append(stanzas, (*Stanza)(s))
222237
}
223238
errNoMatch := &NoIdentityMatchError{}
224239
var fileKey []byte
225240
for _, id := range identities {
241+
var err error
226242
fileKey, err = id.Unwrap(stanzas)
227243
if errors.Is(err, ErrIncorrectIdentity) {
228244
errNoMatch.Errors = append(errNoMatch.Errors, err)
@@ -244,12 +260,7 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
244260
return nil, errors.New("bad header MAC")
245261
}
246262

247-
nonce := make([]byte, streamNonceSize)
248-
if _, err := io.ReadFull(payload, nonce); err != nil {
249-
return nil, fmt.Errorf("failed to read nonce: %w", err)
250-
}
251-
252-
return stream.NewReader(streamKey(fileKey, nonce), payload)
263+
return fileKey, nil
253264
}
254265

255266
// multiUnwrap is a helper that implements Identity.Unwrap in terms of a
@@ -270,3 +281,56 @@ func multiUnwrap(unwrap func(*Stanza) ([]byte, error), stanzas []*Stanza) ([]byt
270281
}
271282
return nil, ErrIncorrectIdentity
272283
}
284+
285+
// ExtractHeader returns a detached header from the src file.
286+
//
287+
// The detached header can be decrypted with [DecryptHeader] (for example on a
288+
// different system, without sharing the ciphertext) and then the file key can
289+
// be used with [NewInjectedFileKeyIdentity].
290+
//
291+
// This is a low-level function that most users won't need.
292+
func ExtractHeader(src io.Reader) ([]byte, error) {
293+
hdr, _, err := format.Parse(src)
294+
if err != nil {
295+
return nil, fmt.Errorf("failed to read header: %w", err)
296+
}
297+
buf := &bytes.Buffer{}
298+
if err := hdr.Marshal(buf); err != nil {
299+
return nil, fmt.Errorf("failed to serialize header: %w", err)
300+
}
301+
return buf.Bytes(), nil
302+
}
303+
304+
// DecryptHeader decrypts a detached header and returns a file key.
305+
//
306+
// The detached header can be produced by [ExtractHeader], and the
307+
// returned file key can be used with [NewInjectedFileKeyIdentity].
308+
//
309+
// This is a low-level function that most users won't need.
310+
// It is the caller's responsibility to keep track of what file the
311+
// returned file key decrypts, and to ensure the file key is not used
312+
// for any other purpose.
313+
func DecryptHeader(header []byte, identities ...Identity) ([]byte, error) {
314+
hdr, _, err := format.Parse(bytes.NewReader(header))
315+
if err != nil {
316+
return nil, fmt.Errorf("failed to read header: %w", err)
317+
}
318+
return decryptHdr(hdr, identities...)
319+
}
320+
321+
type injectedFileKeyIdentity struct {
322+
fileKey []byte
323+
}
324+
325+
// NewInjectedFileKeyIdentity returns an [Identity] that always produces
326+
// a fixed file key, allowing the use of a file key obtained out-of-band,
327+
// for example via [DecryptHeader].
328+
//
329+
// This is a low-level function that most users won't need.
330+
func NewInjectedFileKeyIdentity(fileKey []byte) Identity {
331+
return injectedFileKeyIdentity{fileKey}
332+
}
333+
334+
func (i injectedFileKeyIdentity) Unwrap(stanzas []*Stanza) (fileKey []byte, err error) {
335+
return i.fileKey, nil
336+
}

age_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,46 @@ func TestLabels(t *testing.T) {
284284
t.Errorf("expected pqc+foo mixed with foo+pqc to work, got %v", err)
285285
}
286286
}
287+
288+
func TestDetachedHeader(t *testing.T) {
289+
i, err := age.GenerateX25519Identity()
290+
if err != nil {
291+
t.Fatal(err)
292+
}
293+
294+
buf := &bytes.Buffer{}
295+
w, err := age.Encrypt(buf, i.Recipient())
296+
if err != nil {
297+
t.Fatal(err)
298+
}
299+
if _, err := io.WriteString(w, helloWorld); err != nil {
300+
t.Fatal(err)
301+
}
302+
if err := w.Close(); err != nil {
303+
t.Fatal(err)
304+
}
305+
encrypted := buf.Bytes()
306+
307+
header, err := age.ExtractHeader(bytes.NewReader(encrypted))
308+
if err != nil {
309+
t.Fatal(err)
310+
}
311+
312+
fileKey, err := age.DecryptHeader(header, i)
313+
if err != nil {
314+
t.Fatal(err)
315+
}
316+
317+
identity := age.NewInjectedFileKeyIdentity(fileKey)
318+
out, err := age.Decrypt(bytes.NewReader(encrypted), identity)
319+
if err != nil {
320+
t.Fatal(err)
321+
}
322+
outBytes, err := io.ReadAll(out)
323+
if err != nil {
324+
t.Fatal(err)
325+
}
326+
if string(outBytes) != helloWorld {
327+
t.Errorf("wrong data: %q, expected %q", outBytes, helloWorld)
328+
}
329+
}

0 commit comments

Comments
 (0)