Skip to content

Commit f165191

Browse files
committed
feat: replace EmbedPublicKey by option
1 parent 9c7cc1d commit f165191

File tree

5 files changed

+59
-279
lines changed

5 files changed

+59
-279
lines changed

ipns/record.go

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/ipld/go-ipld-prime/datamodel"
1818
basicnode "github.com/ipld/go-ipld-prime/node/basic"
1919
ic "github.com/libp2p/go-libp2p/core/crypto"
20+
ic_pb "github.com/libp2p/go-libp2p/core/crypto/pb"
2021
"github.com/libp2p/go-libp2p/core/peer"
2122
"github.com/multiformats/go-multibase"
2223
"go.uber.org/multierr"
@@ -194,14 +195,21 @@ const (
194195
)
195196

196197
type options struct {
197-
compatibleWithV1 bool
198+
v1Compatibility bool
199+
embedPublicKey *bool
198200
}
199201

200202
type Option func(*options)
201203

202-
func CompatibleWithV1(compatible bool) Option {
203-
return func(opts *options) {
204-
opts.compatibleWithV1 = compatible
204+
func WithV1Compatibility(compatible bool) Option {
205+
return func(o *options) {
206+
o.v1Compatibility = compatible
207+
}
208+
}
209+
210+
func WithPublicKey(embedded bool) Option {
211+
return func(o *options) {
212+
o.embedPublicKey = &embedded
205213
}
206214
}
207215

@@ -243,7 +251,7 @@ func NewRecord(sk ic.PrivKey, value path.Path, seq uint64, eol time.Time, ttl ti
243251
SignatureV2: sig2,
244252
}
245253

246-
if options.compatibleWithV1 {
254+
if options.v1Compatibility {
247255
pb.Value = []byte(value)
248256
typ := ipns_pb.IpnsEntry_EOL
249257
pb.ValidityType = &typ
@@ -263,6 +271,17 @@ func NewRecord(sk ic.PrivKey, value path.Path, seq uint64, eol time.Time, ttl ti
263271
pb.SignatureV1 = sig1
264272
}
265273

274+
// By default, embed public key if it's not a Ed25519 key. Otherwise, only if
275+
// the user has explicitly asked for it to be embedded.
276+
if (options.embedPublicKey == nil && sk.Type() != ic_pb.KeyType_Ed25519) ||
277+
(options.embedPublicKey != nil && *options.embedPublicKey) {
278+
pkBytes, err := ic.MarshalPublicKey(sk.GetPublic())
279+
if err != nil {
280+
return nil, err
281+
}
282+
pb.PubKey = pkBytes
283+
}
284+
266285
return &Record{
267286
pb: pb,
268287
node: node,
@@ -396,31 +415,6 @@ func compare(a, b *Record) (int, error) {
396415
return 0, nil
397416
}
398417

399-
// EmbedPublicKey embeds the given public key in the given [Record]. While not
400-
// strictly required, some nodes (e.g., DHT servers), may reject IPNS Records
401-
// that do not embed their public keys as they may not be able to validate them
402-
// efficiently.
403-
func EmbedPublicKey(r *Record, pk ic.PubKey) error {
404-
// Try extracting the public key from the ID. If we can, do not embed it.
405-
pid, err := peer.IDFromPublicKey(pk)
406-
if err != nil {
407-
return err
408-
}
409-
if _, err := pid.ExtractPublicKey(); err != peer.ErrNoPublicKey {
410-
// Either a *real* error or nil.
411-
return err
412-
}
413-
414-
// We failed to extract the public key from the peer ID, embed it.
415-
pkBytes, err := ic.MarshalPublicKey(pk)
416-
if err != nil {
417-
return err
418-
}
419-
420-
r.pb.PubKey = pkBytes
421-
return nil
422-
}
423-
424418
// ExtractPublicKey extracts a [crypto.PubKey] matching the given [peer.ID] from
425419
// the IPNS Record, if possible.
426420
func ExtractPublicKey(r *Record, pid peer.ID) (ic.PubKey, error) {

ipns/record_test.go

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ func TestNewRecord(t *testing.T) {
107107
t.Run("V1+V2 with option", func(t *testing.T) {
108108
t.Parallel()
109109

110-
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl, CompatibleWithV1(true))
110+
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl, WithV1Compatibility(true))
111111
require.NotEmpty(t, rec.pb.SignatureV1)
112112

113113
_, err := rec.PubKey()
@@ -116,51 +116,48 @@ func TestNewRecord(t *testing.T) {
116116
fieldsMatch(t, rec, testPath, seq, eol, ttl)
117117
fieldsMatchV1(t, rec, testPath, seq, eol, ttl)
118118
})
119-
}
120-
121-
func TestEmbedPublicKey(t *testing.T) {
122-
t.Parallel()
123-
124-
sk, pk, pid := mustKeyPair(t, ic.RSA)
125-
126-
seq := uint64(0)
127-
eol := time.Now().Add(time.Hour)
128-
ttl := time.Minute * 10
129119

130-
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl)
120+
t.Run("Public key embedded by default for non-ed25519 keys", func(t *testing.T) {
121+
t.Parallel()
131122

132-
_, err := rec.PubKey()
133-
require.ErrorIs(t, err, ErrPublicKeyNotFound)
123+
for _, keyType := range []int{ic.RSA, ic.Secp256k1, ic.ECDSA} {
124+
sk, _, _ := mustKeyPair(t, keyType)
125+
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl)
126+
fieldsMatch(t, rec, testPath, seq, eol, ttl)
134127

135-
err = EmbedPublicKey(rec, pk)
136-
require.NoError(t, err)
128+
pk, err := rec.PubKey()
129+
require.NoError(t, err)
130+
require.True(t, pk.Equals(sk.GetPublic()))
131+
}
132+
})
137133

138-
recPK, err := rec.PubKey()
139-
require.NoError(t, err)
134+
t.Run("Public key not embedded by default for ed25519 keys", func(t *testing.T) {
135+
t.Parallel()
140136

141-
recPID, err := peer.IDFromPublicKey(recPK)
142-
require.NoError(t, err)
137+
sk, _, _ := mustKeyPair(t, ic.Ed25519)
138+
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl)
139+
fieldsMatch(t, rec, testPath, seq, eol, ttl)
143140

144-
require.Equal(t, pid, recPID)
141+
_, err := rec.PubKey()
142+
require.ErrorIs(t, err, ErrPublicKeyNotFound)
143+
})
145144
}
146145

147146
func TestExtractPublicKey(t *testing.T) {
148147
t.Parallel()
149148

150149
t.Run("Returns expected public key when embedded in Peer ID", func(t *testing.T) {
151150
sk, pk, pid := mustKeyPair(t, ic.Ed25519)
152-
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10)
151+
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10, WithPublicKey(false))
153152

154153
pk2, err := ExtractPublicKey(rec, pid)
155154
require.Nil(t, err)
156155
require.Equal(t, pk, pk2)
157156
})
158157

159-
t.Run("Returns expected public key when embedded in Record", func(t *testing.T) {
158+
t.Run("Returns expected public key when embedded in Record (by default)", func(t *testing.T) {
160159
sk, pk, pid := mustKeyPair(t, ic.RSA)
161160
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10)
162-
err := EmbedPublicKey(rec, pk)
163-
require.Nil(t, err)
164161

165162
pk2, err := ExtractPublicKey(rec, pid)
166163
require.Nil(t, err)
@@ -169,7 +166,7 @@ func TestExtractPublicKey(t *testing.T) {
169166

170167
t.Run("Errors when not embedded in Record or Peer ID", func(t *testing.T) {
171168
sk, _, pid := mustKeyPair(t, ic.RSA)
172-
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10)
169+
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10, WithPublicKey(false))
173170

174171
pk, err := ExtractPublicKey(rec, pid)
175172
require.Error(t, err)

ipns/validation_test.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ func TestOrdering(t *testing.T) {
7070
func TestValidator(t *testing.T) {
7171
t.Parallel()
7272

73-
check := func(t *testing.T, sk ic.PrivKey, keybook peerstore.KeyBook, key string, val []byte, eol time.Time, exp error) {
73+
check := func(t *testing.T, sk ic.PrivKey, keybook peerstore.KeyBook, key string, val []byte, eol time.Time, exp error, opts ...Option) {
7474
validator := Validator{keybook}
7575
data := val
7676
if data == nil {
7777
// do not call mustNewRecord because that validates the record!
78-
rec, err := NewRecord(sk, testPath, 1, eol, 0)
78+
rec, err := NewRecord(sk, testPath, 1, eol, 0, opts...)
7979
require.NoError(t, err)
8080
data = mustMarshal(t, rec)
8181
}
@@ -99,9 +99,10 @@ func TestValidator(t *testing.T) {
9999
check(t, sk, kb, RoutingKey(pid), nil, ts.Add(time.Hour*-1), ErrExpiredRecord)
100100
check(t, sk, kb, RoutingKey(pid), []byte("bad data"), ts.Add(time.Hour), ErrBadRecord)
101101
check(t, sk, kb, "/ipns/"+"bad key", nil, ts.Add(time.Hour), ErrKeyFormat)
102-
check(t, sk, emptyKB, RoutingKey(pid), nil, ts.Add(time.Hour), ErrPublicKeyNotFound)
103-
check(t, sk2, kb, RoutingKey(pid2), nil, ts.Add(time.Hour), ErrPublicKeyNotFound)
104-
check(t, sk2, kb, RoutingKey(pid), nil, ts.Add(time.Hour), ErrSignature)
102+
check(t, sk, emptyKB, RoutingKey(pid), nil, ts.Add(time.Hour), ErrPublicKeyNotFound, WithPublicKey(false))
103+
check(t, sk2, kb, RoutingKey(pid2), nil, ts.Add(time.Hour), ErrPublicKeyNotFound, WithPublicKey(false))
104+
check(t, sk2, kb, RoutingKey(pid), nil, ts.Add(time.Hour), ErrPublicKeyMismatch)
105+
check(t, sk2, kb, RoutingKey(pid), nil, ts.Add(time.Hour), ErrSignature, WithPublicKey(false))
105106
check(t, sk, kb, "//"+string(pid), nil, ts.Add(time.Hour), ErrInvalidPath)
106107
check(t, sk, kb, "/wrong/"+string(pid), nil, ts.Add(time.Hour), ErrInvalidPath)
107108
})
@@ -128,14 +129,14 @@ func TestValidator(t *testing.T) {
128129
kb, err := pstoremem.NewPeerstore()
129130
require.NoError(t, err)
130131

131-
sk, pk, pid := mustKeyPair(t, ic.RSA)
132-
rec := mustNewRecord(t, sk, testPath, 1, eol, 0)
132+
sk, _, pid := mustKeyPair(t, ic.RSA)
133+
rec := mustNewRecord(t, sk, testPath, 1, eol, 0, WithPublicKey(false))
133134

134135
// Fails with RSA key without embedded public key.
135136
check(t, sk, kb, RoutingKey(pid), mustMarshal(t, rec), eol, ErrPublicKeyNotFound)
136137

137138
// Embeds public key, must work now.
138-
require.NoError(t, EmbedPublicKey(rec, pk))
139+
rec = mustNewRecord(t, sk, testPath, 1, eol, 0)
139140
check(t, sk, kb, RoutingKey(pid), mustMarshal(t, rec), eol, nil)
140141

141142
// Force bad public key. Validation fails.
@@ -163,8 +164,8 @@ func TestValidate(t *testing.T) {
163164

164165
v := Validator{}
165166

166-
rec1 := mustNewRecord(t, sk, path.FromString("/path/1"), 1, eol, 0, CompatibleWithV1(true))
167-
rec2 := mustNewRecord(t, sk, path.FromString("/path/2"), 2, eol, 0, CompatibleWithV1(true))
167+
rec1 := mustNewRecord(t, sk, path.FromString("/path/1"), 1, eol, 0, WithV1Compatibility(true))
168+
rec2 := mustNewRecord(t, sk, path.FromString("/path/2"), 2, eol, 0, WithV1Compatibility(true))
168169

169170
best, err := v.Select(ipnsRoutingKey, [][]byte{mustMarshal(t, rec1), mustMarshal(t, rec2)})
170171
require.NoError(t, err)

0 commit comments

Comments
 (0)