Skip to content

Commit 2bf017c

Browse files
committed
Do not require IHDScriptPubKey for some PSBT operations
1 parent a8a0bad commit 2bf017c

File tree

7 files changed

+62
-65
lines changed

7 files changed

+62
-65
lines changed

NBitcoin/BIP174/PSBTCoin.cs

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -111,53 +111,65 @@ public virtual void AddKeyPath(PubKey pubKey, RootedKeyPath rootedKeyPath)
111111
/// <param name="accountKey">The account key that will be used to sign (ie. 49'/0'/0')</param>
112112
/// <param name="accountKeyPath">The account key path</param>
113113
/// <returns>HD Keys matching master root key</returns>
114-
public IEnumerable<PSBTHDKeyMatch> HDKeysFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
114+
public IEnumerable<PSBTHDKeyMatch> HDKeysFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
115115
{
116116
if (accountKey == null)
117117
throw new ArgumentNullException(nameof(accountKey));
118-
if (accountHDScriptPubKey == null)
119-
throw new ArgumentNullException(nameof(accountHDScriptPubKey));
120118
return HDKeysFor(accountHDScriptPubKey, accountKey, accountKeyPath, accountKey.GetPublicKey().GetHDFingerPrint());
121119
}
122120
internal IEnumerable<PSBTHDKeyMatch> HDKeysFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath, HDFingerprint accountFingerprint)
123121
{
122+
var coinScriptPubKey = this.GetTxOut()?.ScriptPubKey;
123+
bool Match(KeyValuePair<IPubKey, RootedKeyPath> hdKey, HDFingerprint? expectedMasterFp, IHDKey accountKey, KeyPath addressPath)
124+
{
125+
if (expectedMasterFp is not null &&
126+
hdKey.Value.MasterFingerprint != expectedMasterFp.Value)
127+
return false;
128+
if (accountHDScriptPubKey is not null &&
129+
accountHDScriptPubKey.Derive(addressPath).ScriptPubKey != coinScriptPubKey)
130+
return false;
131+
var derived = accountKey.Derive(addressPath);
132+
return hdKey.Key switch
133+
{
134+
PubKey pk => derived.GetPublicKey().Equals(pk),
135+
#if HAS_SPAN
136+
TaprootPubKey tpk => derived.GetPublicKey().TaprootPubKey.Equals(tpk),
137+
#endif
138+
_ => false
139+
};
140+
}
124141
accountKey = accountKey.AsHDKeyCache();
125142
accountHDScriptPubKey = accountHDScriptPubKey?.AsHDKeyCache();
126-
var coinScriptPubKey = this.GetTxOut()?.ScriptPubKey;
127143
foreach (var hdKey in EnumerateKeyPaths())
128144
{
129145
bool matched = false;
130-
146+
var canDeriveHardenedPath = (accountKey.CanDeriveHardenedPath() && (accountHDScriptPubKey == null || accountHDScriptPubKey.CanDeriveHardenedPath()));
131147
// The case where the fingerprint of the hdkey is exactly equal to the accountKey
132-
if (hdKey.Value.MasterFingerprint == accountFingerprint)
148+
if (!hdKey.Value.KeyPath.IsHardenedPath || canDeriveHardenedPath)
133149
{
134-
// The fingerprint match, but we need to check the public keys, because fingerprint collision is easy to provoke
135-
if (!hdKey.Value.KeyPath.IsHardenedPath || (accountKey.CanDeriveHardenedPath() && (accountHDScriptPubKey == null || accountHDScriptPubKey.CanDeriveHardenedPath())))
150+
if (Match(hdKey, accountFingerprint, accountKey, hdKey.Value.KeyPath))
136151
{
137-
if (accountHDScriptPubKey == null || accountHDScriptPubKey.Derive(hdKey.Value.KeyPath).ScriptPubKey == coinScriptPubKey)
138-
{
139-
yield return CreateHDKeyMatch(accountKey, hdKey.Value.KeyPath, hdKey);
140-
matched = true;
141-
}
152+
yield return CreateHDKeyMatch(accountKey, hdKey.Value.KeyPath, hdKey);
153+
matched = true;
142154
}
143155
}
144156

145157
// The typical case where accountkey is based on an hardened derivation (eg. 49'/0'/0')
146-
if (!matched && accountKeyPath?.MasterFingerprint is HDFingerprint mp && hdKey.Value.MasterFingerprint == mp)
158+
if (!matched && accountKeyPath?.MasterFingerprint is HDFingerprint mp)
147159
{
148160
var addressPath = hdKey.Value.KeyPath.GetAddressKeyPath();
149161
// The cases where addresses are generated on a non-hardened path below it (eg. 49'/0'/0'/0/1)
150162
if (addressPath.Indexes.Length != 0)
151163
{
152-
if (accountHDScriptPubKey == null || accountHDScriptPubKey.Derive(addressPath).ScriptPubKey == coinScriptPubKey)
164+
if (Match(hdKey, mp, accountKey, addressPath))
153165
{
154166
yield return CreateHDKeyMatch(accountKey, addressPath, hdKey);
155167
matched = true;
156168
}
157169
}
158170
// in some cases addresses are generated on a hardened path below the account key (eg. 49'/0'/0'/0'/1') in which case we
159171
// need to brute force what the address key path is
160-
else if (accountKey.CanDeriveHardenedPath() && (accountHDScriptPubKey == null || accountHDScriptPubKey.CanDeriveHardenedPath())) // We can only do this if we can derive hardened paths
172+
else if (canDeriveHardenedPath) // We can only do this if we can derive hardened paths
161173
{
162174
int addressPathSize = 0;
163175
var hdKeyIndexes = hdKey.Value.KeyPath.Indexes;
@@ -166,10 +178,9 @@ internal IEnumerable<PSBTHDKeyMatch> HDKeysFor(IHDScriptPubKey? accountHDScriptP
166178
var indexes = new uint[addressPathSize];
167179
Array.Copy(hdKeyIndexes, hdKey.Value.KeyPath.Length - addressPathSize, indexes, 0, addressPathSize);
168180
addressPath = new KeyPath(indexes);
169-
if (accountKey.Derive(addressPath).GetPublicKey().Equals(hdKey.Key))
181+
if (Match(hdKey, null, accountKey, addressPath))
170182
{
171-
if (accountHDScriptPubKey == null || accountHDScriptPubKey.Derive(addressPath).ScriptPubKey == coinScriptPubKey)
172-
yield return CreateHDKeyMatch(accountKey, addressPath, hdKey);
183+
yield return CreateHDKeyMatch(accountKey, addressPath, hdKey);
173184
matched = true;
174185
break;
175186
}

NBitcoin/BIP174/PSBTCoinList.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ public class PSBTCoinList<T> : IReadOnlyList<T> where T : PSBTCoin
106106
/// <param name="accountKey">The account key that will be used to sign (ie. 49'/0'/0')</param>
107107
/// <param name="accountKeyPath">The account key path</param>
108108
/// <returns>Inputs with HD keys matching masterFingerprint and account key</returns>
109-
public IEnumerable<T> CoinsFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
109+
public IEnumerable<T> CoinsFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
110110
{
111111
return GetPSBTCoins(accountHDScriptPubKey, accountKey, accountKeyPath);
112112
}
@@ -119,7 +119,7 @@ public IEnumerable<T> CoinsFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey acc
119119
/// <param name="accountKey">The account key that will be used to sign (ie. 49'/0'/0')</param>
120120
/// <param name="accountKeyPath">The account key path</param>
121121
/// <returns>HD Keys matching master root key</returns>
122-
public IEnumerable<PSBTHDKeyMatch<T>> HDKeysFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
122+
public IEnumerable<PSBTHDKeyMatch<T>> HDKeysFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
123123
{
124124
return GetHDKeys(accountHDScriptPubKey, accountKey, accountKeyPath);
125125
}
@@ -135,7 +135,7 @@ public IEnumerable<PSBTHDKeyMatch<T>> HDKeysFor(IHDKey accountKey, RootedKeyPath
135135
return GetHDKeys(null, accountKey, accountKeyPath);
136136
}
137137

138-
internal IEnumerable<T> GetPSBTCoins(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
138+
internal IEnumerable<T> GetPSBTCoins(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
139139
{
140140
return GetHDKeys(accountHDScriptPubKey, accountKey, accountKeyPath)
141141
.Select(c => c.Coin)
@@ -147,7 +147,6 @@ internal IEnumerable<PSBTHDKeyMatch<T>> GetHDKeys(IHDScriptPubKey? hdScriptPubKe
147147
if (accountKey == null)
148148
throw new ArgumentNullException(nameof(accountKey));
149149
accountKey = accountKey.AsHDKeyCache();
150-
hdScriptPubKey = hdScriptPubKey?.AsHDKeyCache();
151150
var accountFingerprint = accountKey.GetPublicKey().GetHDFingerPrint();
152151
foreach (var c in this)
153152
{

NBitcoin/BIP174/PSBTInput.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -504,16 +504,13 @@ public void TrySign(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, Ro
504504
{
505505
TrySign(accountHDScriptPubKey, accountKey, accountKeyPath, null);
506506
}
507-
internal void TrySign(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath, SigningOptions? signingOptions)
507+
internal void TrySign(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath, SigningOptions? signingOptions)
508508
{
509509
if (accountKey == null)
510510
throw new ArgumentNullException(nameof(accountKey));
511-
if (accountHDScriptPubKey == null)
512-
throw new ArgumentNullException(nameof(accountHDScriptPubKey));
513511
if (IsFinalized())
514512
return;
515513
var cache = accountKey.AsHDKeyCache();
516-
accountHDScriptPubKey = accountHDScriptPubKey.AsHDKeyCache();
517514
foreach (var hdk in this.HDKeysFor(accountHDScriptPubKey, cache, accountKeyPath))
518515
{
519516
if (((HDKeyCache)cache.Derive(hdk.AddressKeyPath)).Inner is ISecret k)

NBitcoin/BIP174/PartiallySignedTransaction.cs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -530,13 +530,11 @@ public PSBT SignAll(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey)
530530
/// <param name="accountKey">The account key with which to sign</param>
531531
/// <param name="accountKeyPath">The account key path (eg. [masterFP]/49'/0'/0')</param>
532532
/// <returns>This PSBT</returns>
533-
public PSBT SignAll(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath)
533+
public PSBT SignAll(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath)
534534
{
535535
if (accountKey == null)
536536
throw new ArgumentNullException(nameof(accountKey));
537-
if (accountHDScriptPubKey == null)
538-
throw new ArgumentNullException(nameof(accountHDScriptPubKey));
539-
accountHDScriptPubKey = accountHDScriptPubKey.AsHDKeyCache();
537+
accountHDScriptPubKey = accountHDScriptPubKey?.AsHDKeyCache();
540538
accountKey = accountKey.AsHDKeyCache();
541539
Money total = Money.Zero;
542540

@@ -1100,13 +1098,10 @@ public Money GetBalance(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey
11001098
/// <param name="accountKey">The account key that will be used to sign (ie. 49'/0'/0')</param>
11011099
/// <param name="accountKeyPath">The account key path</param>
11021100
/// <returns>Inputs with HD keys matching masterFingerprint and account key</returns>
1103-
public IEnumerable<PSBTCoin> CoinsFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
1101+
public IEnumerable<PSBTCoin> CoinsFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
11041102
{
11051103
if (accountKey == null)
11061104
throw new ArgumentNullException(nameof(accountKey));
1107-
if (accountHDScriptPubKey == null)
1108-
throw new ArgumentNullException(nameof(accountHDScriptPubKey));
1109-
accountHDScriptPubKey = accountHDScriptPubKey.AsHDKeyCache();
11101105
accountKey = accountKey.AsHDKeyCache();
11111106
return Inputs.CoinsFor(accountHDScriptPubKey, accountKey, accountKeyPath).OfType<PSBTCoin>().Concat(Outputs.CoinsFor(accountHDScriptPubKey, accountKey, accountKeyPath).OfType<PSBTCoin>());
11121107
}
@@ -1119,13 +1114,10 @@ public IEnumerable<PSBTCoin> CoinsFor(IHDScriptPubKey accountHDScriptPubKey, IHD
11191114
/// <param name="accountKey">The account key that will be used to sign (ie. 49'/0'/0')</param>
11201115
/// <param name="accountKeyPath">The account key path</param>
11211116
/// <returns>HD Keys matching master root key</returns>
1122-
public IEnumerable<PSBTHDKeyMatch> HDKeysFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
1117+
public IEnumerable<PSBTHDKeyMatch> HDKeysFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
11231118
{
11241119
if (accountKey == null)
11251120
throw new ArgumentNullException(nameof(accountKey));
1126-
if (accountHDScriptPubKey == null)
1127-
throw new ArgumentNullException(nameof(accountHDScriptPubKey));
1128-
accountHDScriptPubKey = accountHDScriptPubKey.AsHDKeyCache();
11291121
accountKey = accountKey.AsHDKeyCache();
11301122
return Inputs.HDKeysFor(accountHDScriptPubKey, accountKey, accountKeyPath).OfType<PSBTHDKeyMatch>().Concat(Outputs.HDKeysFor(accountHDScriptPubKey, accountKey, accountKeyPath));
11311123
}

NBitcoin/Network.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ internal static string CreateBase58(Base58Type type, byte[] bytes, Network netwo
203203
throw new ArgumentNullException(nameof(network));
204204
if (bytes == null)
205205
throw new ArgumentNullException(nameof(bytes));
206-
var versionBytes = network.GetVersionBytes(type, true);
206+
var versionBytes = network.GetVersionBytes(type, true)!;
207207
return network.NetworkStringParser.GetBase58CheckEncoder().EncodeData(versionBytes.Concat(bytes));
208208
}
209209

NBitcoin/Secp256k1/ECPrivKey.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ public bool TrySignECDSA(ReadOnlySpan<byte> msg32, INonceFunction? nonceFunction
711711
{
712712
return TrySignECDSA(msg32, nonceFunction, out _, out signature);
713713
}
714-
public bool TrySignECDSA(ReadOnlySpan<byte> msg32, INonceFunction? nonceFunction, out int recid, out SecpECDSASignature? signature)
714+
public bool TrySignECDSA(ReadOnlySpan<byte> msg32, INonceFunction? nonceFunction, out int recid, [MaybeNullWhen(false)] out SecpECDSASignature signature)
715715
{
716716
AssertNotDisposed();
717717
recid = 0;

NBitcoin/Utils.cs

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using NBitcoin.DataEncoders;
1+
#nullable enable
2+
using NBitcoin.DataEncoders;
23
using System;
34
using System.Collections.Generic;
45
using System.IO;
@@ -15,6 +16,7 @@
1516
using System.Runtime.InteropServices;
1617
#if !NOSOCKET
1718
using System.Net.Sockets;
19+
using System.Diagnostics.CodeAnalysis;
1820
#endif
1921
#if WINDOWS_UWP
2022
using System.Net.Sockets;
@@ -35,24 +37,24 @@ internal static Secp256k1.SecpECDSASignature Sign(this Secp256k1.ECPrivKey key,
3537
{
3638
Span<byte> hash = stackalloc byte[32];
3739
h.ToBytes(hash);
38-
byte[] extra_entropy = null;
39-
Secp256k1.RFC6979NonceFunction nonceFunction = null;
40+
byte[]? extra_entropy = null;
41+
Secp256k1.RFC6979NonceFunction? nonceFunction = null;
4042
Span<byte> vchSig = stackalloc byte[Secp256k1.SecpECDSASignature.MaxLength];
41-
Secp256k1.SecpECDSASignature sig;
43+
Secp256k1.SecpECDSASignature? sig;
4244
uint counter = 0;
43-
bool ret = key.TrySignECDSA(hash, null, out recid, out sig);
45+
key.TrySignECDSA(hash, null, out recid, out sig);
4446
// Grind for low R
45-
while (ret && sig.r.IsHigh && enforceLowR)
47+
while (sig is not null && sig.r.IsHigh && enforceLowR)
4648
{
4749
if (extra_entropy == null || nonceFunction == null)
4850
{
4951
extra_entropy = new byte[32];
5052
nonceFunction = new Secp256k1.RFC6979NonceFunction(extra_entropy);
5153
}
5254
Utils.ToBytes(++counter, true, extra_entropy.AsSpan());
53-
ret = key.TrySignECDSA(hash, nonceFunction, out recid, out sig);
55+
key.TrySignECDSA(hash, nonceFunction, out recid, out sig);
5456
}
55-
return sig;
57+
return sig!;
5658
}
5759
#endif
5860
/// <summary>
@@ -174,7 +176,7 @@ public static T ToNetwork<T>(this T obj, ChainName chainName) where T : IBitcoin
174176
{
175177
if (obj == null)
176178
throw new ArgumentNullException(nameof(obj));
177-
if (chainName == null)
179+
if (chainName is null)
178180
throw new ArgumentNullException(nameof(chainName));
179181
if (obj.Network.ChainName == chainName)
180182
return obj;
@@ -183,7 +185,7 @@ public static T ToNetwork<T>(this T obj, ChainName chainName) where T : IBitcoin
183185

184186
public static T ToNetwork<T>(this T obj, Network network) where T : IBitcoinString
185187
{
186-
if (network == null)
188+
if (network is null)
187189
throw new ArgumentNullException(nameof(network));
188190
if (obj == null)
189191
throw new ArgumentNullException(nameof(obj));
@@ -195,7 +197,7 @@ public static T ToNetwork<T>(this T obj, Network network) where T : IBitcoinStri
195197
if (b58.Type != Base58Type.COLORED_ADDRESS)
196198
{
197199

198-
byte[] version = network.GetVersionBytes(b58.Type, true);
200+
byte[] version = network.GetVersionBytes(b58.Type, true)!;
199201
var enc = network.NetworkStringParser.GetBase58CheckEncoder();
200202
var inner = enc.DecodeData(b58.ToString()).Skip(version.Length).ToArray();
201203
var newBase58 = enc.EncodeData(version.Concat(inner).ToArray());
@@ -211,10 +213,10 @@ public static T ToNetwork<T>(this T obj, Network network) where T : IBitcoinStri
211213
else if (obj is IBech32Data)
212214
{
213215
var b32 = (IBech32Data)obj;
214-
var encoder = b32.Network.GetBech32Encoder(b32.Type, true);
216+
var encoder = b32.Network.GetBech32Encoder(b32.Type, true)!;
215217
byte wit;
216218
var data = encoder.Decode(b32.ToString(), out wit);
217-
encoder = network.GetBech32Encoder(b32.Type, true);
219+
encoder = network.GetBech32Encoder(b32.Type, true)!;
218220
var str = encoder.Encode(wit, data);
219221
return (T)(object)Network.Parse<T>(str, network);
220222
}
@@ -257,12 +259,12 @@ public static int ReadBytes(this Stream stream, int count, out byte[] result)
257259
result = new byte[count];
258260
return stream.Read(result, 0, count);
259261
}
260-
public static IEnumerable<T> Resize<T>(this List<T> list, int count)
262+
public static IEnumerable<T?> Resize<T>(this List<T?> list, int count)
261263
{
262264
if (list.Count == count)
263265
return new T[0];
264266

265-
List<T> removed = new List<T>();
267+
var removed = new List<T?>();
266268

267269
for (int i = list.Count - 1; i + 1 > count; i--)
268270
{
@@ -419,9 +421,9 @@ public static void AddOrReplace<TKey, TValue>(this IDictionary<TKey, TValue> dic
419421
dico.Add(key, value);
420422
}
421423

422-
public static TValue TryGet<TKey, TValue>(this IDictionary<TKey, TValue> dictionary, TKey key)
424+
public static TValue? TryGet<TKey, TValue>(this IDictionary<TKey, TValue> dictionary, TKey key)
423425
{
424-
TValue value;
426+
TValue? value;
425427
dictionary.TryGetValue(key, out value);
426428
return value;
427429
}
@@ -568,10 +570,6 @@ private static void Write(MemoryStream ms, byte[] bytes)
568570
#if !HAS_SPAN
569571
internal static byte[] BigIntegerToBytes(BigInteger b, int numBytes)
570572
{
571-
if (b == null)
572-
{
573-
return null;
574-
}
575573
byte[] bytes = new byte[numBytes];
576574
byte[] biBytes = b.ToByteArray();
577575
int start = (biBytes.Length == numBytes + 1) ? 1 : 0;
@@ -653,7 +651,7 @@ public static DateTimeOffset UnixTimeToDateTime(long timestamp)
653651

654652
public static string ExceptionToString(Exception exception)
655653
{
656-
Exception ex = exception;
654+
Exception? ex = exception;
657655
StringBuilder stringBuilder = new StringBuilder(128);
658656
while (ex != null)
659657
{
@@ -670,7 +668,7 @@ public static string ExceptionToString(Exception exception)
670668
return stringBuilder.ToString();
671669
}
672670

673-
public static void Shuffle<T>(T[] arr, Random rand)
671+
public static void Shuffle<T>(T[] arr, Random? rand)
674672
{
675673
rand = rand ?? new Random();
676674
for (int i = 0; i < arr.Length; i++)
@@ -1000,7 +998,7 @@ public static ulong ToUInt64(byte[] value, bool littleEndian)
1000998

1001999
#if !NOSOCKET
10021000

1003-
public static bool TryParseEndpoint(string hostPort, int defaultPort, out EndPoint endpoint)
1001+
public static bool TryParseEndpoint(string hostPort, int defaultPort, [MaybeNullWhen(false)] out EndPoint endpoint)
10041002
{
10051003
if (hostPort == null)
10061004
throw new ArgumentNullException(nameof(hostPort));

0 commit comments

Comments
 (0)