diff --git a/NBitcoin/BIP174/PSBTCoin.cs b/NBitcoin/BIP174/PSBTCoin.cs index 040e9eca3..6a10df18e 100644 --- a/NBitcoin/BIP174/PSBTCoin.cs +++ b/NBitcoin/BIP174/PSBTCoin.cs @@ -111,45 +111,57 @@ public virtual void AddKeyPath(PubKey pubKey, RootedKeyPath rootedKeyPath) /// The account key that will be used to sign (ie. 49'/0'/0') /// The account key path /// HD Keys matching master root key - public IEnumerable HDKeysFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null) + public IEnumerable HDKeysFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null) { if (accountKey == null) throw new ArgumentNullException(nameof(accountKey)); - if (accountHDScriptPubKey == null) - throw new ArgumentNullException(nameof(accountHDScriptPubKey)); return HDKeysFor(accountHDScriptPubKey, accountKey, accountKeyPath, accountKey.GetPublicKey().GetHDFingerPrint()); } internal IEnumerable HDKeysFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath, HDFingerprint accountFingerprint) { + var coinScriptPubKey = this.GetTxOut()?.ScriptPubKey; + bool Match(KeyValuePair hdKey, HDFingerprint? expectedMasterFp, IHDKey accountKey, KeyPath addressPath) + { + if (expectedMasterFp is not null && + hdKey.Value.MasterFingerprint != expectedMasterFp.Value) + return false; + if (accountHDScriptPubKey is not null && + accountHDScriptPubKey.Derive(addressPath).ScriptPubKey != coinScriptPubKey) + return false; + var derived = accountKey.Derive(addressPath); + return hdKey.Key switch + { + PubKey pk => derived.GetPublicKey().Equals(pk), +#if HAS_SPAN + TaprootPubKey tpk => derived.GetPublicKey().GetTaprootFullPubKey() is var tfp && (tfp.OutputKey.Equals(tpk) || tfp.InternalKey.AsTaprootPubKey().Equals(tpk)), +#endif + _ => false + }; + } accountKey = accountKey.AsHDKeyCache(); accountHDScriptPubKey = accountHDScriptPubKey?.AsHDKeyCache(); - var coinScriptPubKey = this.GetTxOut()?.ScriptPubKey; foreach (var hdKey in EnumerateKeyPaths()) { bool matched = false; - + var canDeriveHardenedPath = (accountKey.CanDeriveHardenedPath() && (accountHDScriptPubKey == null || accountHDScriptPubKey.CanDeriveHardenedPath())); // The case where the fingerprint of the hdkey is exactly equal to the accountKey - if (hdKey.Value.MasterFingerprint == accountFingerprint) + if (!hdKey.Value.KeyPath.IsHardenedPath || canDeriveHardenedPath) { - // The fingerprint match, but we need to check the public keys, because fingerprint collision is easy to provoke - if (!hdKey.Value.KeyPath.IsHardenedPath || (accountKey.CanDeriveHardenedPath() && (accountHDScriptPubKey == null || accountHDScriptPubKey.CanDeriveHardenedPath()))) + if (Match(hdKey, accountFingerprint, accountKey, hdKey.Value.KeyPath)) { - if (accountHDScriptPubKey == null || accountHDScriptPubKey.Derive(hdKey.Value.KeyPath).ScriptPubKey == coinScriptPubKey) - { - yield return CreateHDKeyMatch(accountKey, hdKey.Value.KeyPath, hdKey); - matched = true; - } + yield return CreateHDKeyMatch(accountKey, hdKey.Value.KeyPath, hdKey); + matched = true; } } // The typical case where accountkey is based on an hardened derivation (eg. 49'/0'/0') - if (!matched && accountKeyPath?.MasterFingerprint is HDFingerprint mp && hdKey.Value.MasterFingerprint == mp) + if (!matched && accountKeyPath?.MasterFingerprint is HDFingerprint mp) { var addressPath = hdKey.Value.KeyPath.GetAddressKeyPath(); // The cases where addresses are generated on a non-hardened path below it (eg. 49'/0'/0'/0/1) if (addressPath.Indexes.Length != 0) { - if (accountHDScriptPubKey == null || accountHDScriptPubKey.Derive(addressPath).ScriptPubKey == coinScriptPubKey) + if (Match(hdKey, mp, accountKey, addressPath)) { yield return CreateHDKeyMatch(accountKey, addressPath, hdKey); matched = true; @@ -157,7 +169,7 @@ internal IEnumerable HDKeysFor(IHDScriptPubKey? accountHDScriptP } // in some cases addresses are generated on a hardened path below the account key (eg. 49'/0'/0'/0'/1') in which case we // need to brute force what the address key path is - else if (accountKey.CanDeriveHardenedPath() && (accountHDScriptPubKey == null || accountHDScriptPubKey.CanDeriveHardenedPath())) // We can only do this if we can derive hardened paths + else if (canDeriveHardenedPath) // We can only do this if we can derive hardened paths { int addressPathSize = 0; var hdKeyIndexes = hdKey.Value.KeyPath.Indexes; @@ -166,10 +178,9 @@ internal IEnumerable HDKeysFor(IHDScriptPubKey? accountHDScriptP var indexes = new uint[addressPathSize]; Array.Copy(hdKeyIndexes, hdKey.Value.KeyPath.Length - addressPathSize, indexes, 0, addressPathSize); addressPath = new KeyPath(indexes); - if (accountKey.Derive(addressPath).GetPublicKey().Equals(hdKey.Key)) + if (Match(hdKey, null, accountKey, addressPath)) { - if (accountHDScriptPubKey == null || accountHDScriptPubKey.Derive(addressPath).ScriptPubKey == coinScriptPubKey) - yield return CreateHDKeyMatch(accountKey, addressPath, hdKey); + yield return CreateHDKeyMatch(accountKey, addressPath, hdKey); matched = true; break; } diff --git a/NBitcoin/BIP174/PSBTCoinList.cs b/NBitcoin/BIP174/PSBTCoinList.cs index bfebd13c7..dd8bb91d9 100644 --- a/NBitcoin/BIP174/PSBTCoinList.cs +++ b/NBitcoin/BIP174/PSBTCoinList.cs @@ -106,7 +106,7 @@ public class PSBTCoinList : IReadOnlyList where T : PSBTCoin /// The account key that will be used to sign (ie. 49'/0'/0') /// The account key path /// Inputs with HD keys matching masterFingerprint and account key - public IEnumerable CoinsFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null) + public IEnumerable CoinsFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null) { return GetPSBTCoins(accountHDScriptPubKey, accountKey, accountKeyPath); } @@ -119,7 +119,7 @@ public IEnumerable CoinsFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey acc /// The account key that will be used to sign (ie. 49'/0'/0') /// The account key path /// HD Keys matching master root key - public IEnumerable> HDKeysFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null) + public IEnumerable> HDKeysFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null) { return GetHDKeys(accountHDScriptPubKey, accountKey, accountKeyPath); } @@ -135,7 +135,7 @@ public IEnumerable> HDKeysFor(IHDKey accountKey, RootedKeyPath return GetHDKeys(null, accountKey, accountKeyPath); } - internal IEnumerable GetPSBTCoins(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null) + internal IEnumerable GetPSBTCoins(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null) { return GetHDKeys(accountHDScriptPubKey, accountKey, accountKeyPath) .Select(c => c.Coin) @@ -147,7 +147,6 @@ internal IEnumerable> GetHDKeys(IHDScriptPubKey? hdScriptPubKe if (accountKey == null) throw new ArgumentNullException(nameof(accountKey)); accountKey = accountKey.AsHDKeyCache(); - hdScriptPubKey = hdScriptPubKey?.AsHDKeyCache(); var accountFingerprint = accountKey.GetPublicKey().GetHDFingerPrint(); foreach (var c in this) { diff --git a/NBitcoin/BIP174/PSBTInput.cs b/NBitcoin/BIP174/PSBTInput.cs index 8a5ef41df..18c4721fa 100644 --- a/NBitcoin/BIP174/PSBTInput.cs +++ b/NBitcoin/BIP174/PSBTInput.cs @@ -504,16 +504,13 @@ public void TrySign(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, Ro { TrySign(accountHDScriptPubKey, accountKey, accountKeyPath, null); } - internal void TrySign(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath, SigningOptions? signingOptions) + internal void TrySign(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath, SigningOptions? signingOptions) { if (accountKey == null) throw new ArgumentNullException(nameof(accountKey)); - if (accountHDScriptPubKey == null) - throw new ArgumentNullException(nameof(accountHDScriptPubKey)); if (IsFinalized()) return; var cache = accountKey.AsHDKeyCache(); - accountHDScriptPubKey = accountHDScriptPubKey.AsHDKeyCache(); foreach (var hdk in this.HDKeysFor(accountHDScriptPubKey, cache, accountKeyPath)) { if (((HDKeyCache)cache.Derive(hdk.AddressKeyPath)).Inner is ISecret k) diff --git a/NBitcoin/BIP174/PartiallySignedTransaction.cs b/NBitcoin/BIP174/PartiallySignedTransaction.cs index a2b5fb2ce..4e9a07518 100644 --- a/NBitcoin/BIP174/PartiallySignedTransaction.cs +++ b/NBitcoin/BIP174/PartiallySignedTransaction.cs @@ -530,13 +530,11 @@ public PSBT SignAll(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey) /// The account key with which to sign /// The account key path (eg. [masterFP]/49'/0'/0') /// This PSBT - public PSBT SignAll(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath) + public PSBT SignAll(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath) { if (accountKey == null) throw new ArgumentNullException(nameof(accountKey)); - if (accountHDScriptPubKey == null) - throw new ArgumentNullException(nameof(accountHDScriptPubKey)); - accountHDScriptPubKey = accountHDScriptPubKey.AsHDKeyCache(); + accountHDScriptPubKey = accountHDScriptPubKey?.AsHDKeyCache(); accountKey = accountKey.AsHDKeyCache(); Money total = Money.Zero; @@ -1100,13 +1098,10 @@ public Money GetBalance(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey /// The account key that will be used to sign (ie. 49'/0'/0') /// The account key path /// Inputs with HD keys matching masterFingerprint and account key - public IEnumerable CoinsFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null) + public IEnumerable CoinsFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null) { if (accountKey == null) throw new ArgumentNullException(nameof(accountKey)); - if (accountHDScriptPubKey == null) - throw new ArgumentNullException(nameof(accountHDScriptPubKey)); - accountHDScriptPubKey = accountHDScriptPubKey.AsHDKeyCache(); accountKey = accountKey.AsHDKeyCache(); return Inputs.CoinsFor(accountHDScriptPubKey, accountKey, accountKeyPath).OfType().Concat(Outputs.CoinsFor(accountHDScriptPubKey, accountKey, accountKeyPath).OfType()); } @@ -1119,13 +1114,10 @@ public IEnumerable CoinsFor(IHDScriptPubKey accountHDScriptPubKey, IHD /// The account key that will be used to sign (ie. 49'/0'/0') /// The account key path /// HD Keys matching master root key - public IEnumerable HDKeysFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null) + public IEnumerable HDKeysFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null) { if (accountKey == null) throw new ArgumentNullException(nameof(accountKey)); - if (accountHDScriptPubKey == null) - throw new ArgumentNullException(nameof(accountHDScriptPubKey)); - accountHDScriptPubKey = accountHDScriptPubKey.AsHDKeyCache(); accountKey = accountKey.AsHDKeyCache(); return Inputs.HDKeysFor(accountHDScriptPubKey, accountKey, accountKeyPath).OfType().Concat(Outputs.HDKeysFor(accountHDScriptPubKey, accountKey, accountKeyPath)); } diff --git a/NBitcoin/Network.cs b/NBitcoin/Network.cs index 06aa0e8dc..32220517d 100644 --- a/NBitcoin/Network.cs +++ b/NBitcoin/Network.cs @@ -203,7 +203,7 @@ internal static string CreateBase58(Base58Type type, byte[] bytes, Network netwo throw new ArgumentNullException(nameof(network)); if (bytes == null) throw new ArgumentNullException(nameof(bytes)); - var versionBytes = network.GetVersionBytes(type, true); + var versionBytes = network.GetVersionBytes(type, true)!; return network.NetworkStringParser.GetBase58CheckEncoder().EncodeData(versionBytes.Concat(bytes)); } diff --git a/NBitcoin/Secp256k1/ECPrivKey.cs b/NBitcoin/Secp256k1/ECPrivKey.cs index 7285ee9e6..5129f3613 100644 --- a/NBitcoin/Secp256k1/ECPrivKey.cs +++ b/NBitcoin/Secp256k1/ECPrivKey.cs @@ -711,7 +711,7 @@ public bool TrySignECDSA(ReadOnlySpan msg32, INonceFunction? nonceFunction { return TrySignECDSA(msg32, nonceFunction, out _, out signature); } - public bool TrySignECDSA(ReadOnlySpan msg32, INonceFunction? nonceFunction, out int recid, out SecpECDSASignature? signature) + public bool TrySignECDSA(ReadOnlySpan msg32, INonceFunction? nonceFunction, out int recid, [MaybeNullWhen(false)] out SecpECDSASignature signature) { AssertNotDisposed(); recid = 0; diff --git a/NBitcoin/Utils.cs b/NBitcoin/Utils.cs index 77b622e68..a5a2d4ccd 100644 --- a/NBitcoin/Utils.cs +++ b/NBitcoin/Utils.cs @@ -1,4 +1,5 @@ -using NBitcoin.DataEncoders; +#nullable enable +using NBitcoin.DataEncoders; using System; using System.Collections.Generic; using System.IO; @@ -15,6 +16,7 @@ using System.Runtime.InteropServices; #if !NOSOCKET using System.Net.Sockets; +using System.Diagnostics.CodeAnalysis; #endif #if WINDOWS_UWP using System.Net.Sockets; @@ -35,14 +37,14 @@ internal static Secp256k1.SecpECDSASignature Sign(this Secp256k1.ECPrivKey key, { Span hash = stackalloc byte[32]; h.ToBytes(hash); - byte[] extra_entropy = null; - Secp256k1.RFC6979NonceFunction nonceFunction = null; + byte[]? extra_entropy = null; + Secp256k1.RFC6979NonceFunction? nonceFunction = null; Span vchSig = stackalloc byte[Secp256k1.SecpECDSASignature.MaxLength]; - Secp256k1.SecpECDSASignature sig; + Secp256k1.SecpECDSASignature? sig; uint counter = 0; - bool ret = key.TrySignECDSA(hash, null, out recid, out sig); + key.TrySignECDSA(hash, null, out recid, out sig); // Grind for low R - while (ret && sig.r.IsHigh && enforceLowR) + while (sig is not null && sig.r.IsHigh && enforceLowR) { if (extra_entropy == null || nonceFunction == null) { @@ -50,9 +52,9 @@ internal static Secp256k1.SecpECDSASignature Sign(this Secp256k1.ECPrivKey key, nonceFunction = new Secp256k1.RFC6979NonceFunction(extra_entropy); } Utils.ToBytes(++counter, true, extra_entropy.AsSpan()); - ret = key.TrySignECDSA(hash, nonceFunction, out recid, out sig); + key.TrySignECDSA(hash, nonceFunction, out recid, out sig); } - return sig; + return sig!; } #endif /// @@ -174,7 +176,7 @@ public static T ToNetwork(this T obj, ChainName chainName) where T : IBitcoin { if (obj == null) throw new ArgumentNullException(nameof(obj)); - if (chainName == null) + if (chainName is null) throw new ArgumentNullException(nameof(chainName)); if (obj.Network.ChainName == chainName) return obj; @@ -183,7 +185,7 @@ public static T ToNetwork(this T obj, ChainName chainName) where T : IBitcoin public static T ToNetwork(this T obj, Network network) where T : IBitcoinString { - if (network == null) + if (network is null) throw new ArgumentNullException(nameof(network)); if (obj == null) throw new ArgumentNullException(nameof(obj)); @@ -195,7 +197,7 @@ public static T ToNetwork(this T obj, Network network) where T : IBitcoinStri if (b58.Type != Base58Type.COLORED_ADDRESS) { - byte[] version = network.GetVersionBytes(b58.Type, true); + byte[] version = network.GetVersionBytes(b58.Type, true)!; var enc = network.NetworkStringParser.GetBase58CheckEncoder(); var inner = enc.DecodeData(b58.ToString()).Skip(version.Length).ToArray(); var newBase58 = enc.EncodeData(version.Concat(inner).ToArray()); @@ -211,10 +213,10 @@ public static T ToNetwork(this T obj, Network network) where T : IBitcoinStri else if (obj is IBech32Data) { var b32 = (IBech32Data)obj; - var encoder = b32.Network.GetBech32Encoder(b32.Type, true); + var encoder = b32.Network.GetBech32Encoder(b32.Type, true)!; byte wit; var data = encoder.Decode(b32.ToString(), out wit); - encoder = network.GetBech32Encoder(b32.Type, true); + encoder = network.GetBech32Encoder(b32.Type, true)!; var str = encoder.Encode(wit, data); return (T)(object)Network.Parse(str, network); } @@ -257,12 +259,12 @@ public static int ReadBytes(this Stream stream, int count, out byte[] result) result = new byte[count]; return stream.Read(result, 0, count); } - public static IEnumerable Resize(this List list, int count) + public static IEnumerable Resize(this List list, int count) { if (list.Count == count) return new T[0]; - List removed = new List(); + var removed = new List(); for (int i = list.Count - 1; i + 1 > count; i--) { @@ -419,9 +421,9 @@ public static void AddOrReplace(this IDictionary dic dico.Add(key, value); } - public static TValue TryGet(this IDictionary dictionary, TKey key) + public static TValue? TryGet(this IDictionary dictionary, TKey key) { - TValue value; + TValue? value; dictionary.TryGetValue(key, out value); return value; } @@ -568,10 +570,6 @@ private static void Write(MemoryStream ms, byte[] bytes) #if !HAS_SPAN internal static byte[] BigIntegerToBytes(BigInteger b, int numBytes) { - if (b == null) - { - return null; - } byte[] bytes = new byte[numBytes]; byte[] biBytes = b.ToByteArray(); int start = (biBytes.Length == numBytes + 1) ? 1 : 0; @@ -653,7 +651,7 @@ public static DateTimeOffset UnixTimeToDateTime(long timestamp) public static string ExceptionToString(Exception exception) { - Exception ex = exception; + Exception? ex = exception; StringBuilder stringBuilder = new StringBuilder(128); while (ex != null) { @@ -670,7 +668,7 @@ public static string ExceptionToString(Exception exception) return stringBuilder.ToString(); } - public static void Shuffle(T[] arr, Random rand) + public static void Shuffle(T[] arr, Random? rand) { rand = rand ?? new Random(); for (int i = 0; i < arr.Length; i++) @@ -1000,7 +998,7 @@ public static ulong ToUInt64(byte[] value, bool littleEndian) #if !NOSOCKET - public static bool TryParseEndpoint(string hostPort, int defaultPort, out EndPoint endpoint) + public static bool TryParseEndpoint(string hostPort, int defaultPort, [MaybeNullWhen(false)] out EndPoint endpoint) { if (hostPort == null) throw new ArgumentNullException(nameof(hostPort));