Skip to content

Commit 982136d

Browse files
committed
age: use native identities first in Decrypt
1 parent 0d42e7a commit 982136d

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

age.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ func (*NoIdentityMatchError) Error() string {
214214
//
215215
// It returns a Reader reading the decrypted plaintext of the age file read
216216
// from src. All identities will be tried until one successfully decrypts the file.
217+
// Native, non-interactive identities are tried before any other identities.
217218
//
218219
// If no identity matches the encrypted file, the returned error will be of type
219220
// [NoIdentityMatchError].
@@ -240,6 +241,24 @@ func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) {
240241
if len(identities) == 0 {
241242
return nil, errors.New("no identities specified")
242243
}
244+
slices.SortStableFunc(identities, func(a, b Identity) int {
245+
var aIsNative, bIsNative bool
246+
switch a.(type) {
247+
case *X25519Identity, *HybridIdentity, *ScryptIdentity:
248+
aIsNative = true
249+
}
250+
switch b.(type) {
251+
case *X25519Identity, *HybridIdentity, *ScryptIdentity:
252+
bIsNative = true
253+
}
254+
if aIsNative && !bIsNative {
255+
return -1
256+
}
257+
if !aIsNative && bIsNative {
258+
return 1
259+
}
260+
return 0
261+
})
243262

244263
stanzas := make([]*Stanza, 0, len(hdr.Recipients))
245264
for _, s := range hdr.Recipients {

age_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,50 @@ func TestLabels(t *testing.T) {
285285
}
286286
}
287287

288+
// testIdentity is a non-native identity that records if Unwrap is called.
289+
type testIdentity struct {
290+
called bool
291+
}
292+
293+
func (ti *testIdentity) Unwrap(stanzas []*age.Stanza) ([]byte, error) {
294+
ti.called = true
295+
return nil, age.ErrIncorrectIdentity
296+
}
297+
298+
func TestDecryptNativeIdentitiesFirst(t *testing.T) {
299+
correct, err := age.GenerateX25519Identity()
300+
if err != nil {
301+
t.Fatal(err)
302+
}
303+
unrelated, err := age.GenerateX25519Identity()
304+
if err != nil {
305+
t.Fatal(err)
306+
}
307+
308+
buf := &bytes.Buffer{}
309+
w, err := age.Encrypt(buf, correct.Recipient())
310+
if err != nil {
311+
t.Fatal(err)
312+
}
313+
if err := w.Close(); err != nil {
314+
t.Fatal(err)
315+
}
316+
317+
nonNative := &testIdentity{}
318+
319+
// Pass identities: unrelated native, non-native, correct native.
320+
// Native identities should be tried first, so correct should match
321+
// before nonNative is ever called.
322+
_, err = age.Decrypt(bytes.NewReader(buf.Bytes()), unrelated, nonNative, correct)
323+
if err != nil {
324+
t.Fatal(err)
325+
}
326+
327+
if nonNative.called {
328+
t.Error("non-native identity was called, but native identities should be tried first")
329+
}
330+
}
331+
288332
func TestDetachedHeader(t *testing.T) {
289333
i, err := age.GenerateX25519Identity()
290334
if err != nil {

0 commit comments

Comments
 (0)