Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 30 additions & 19 deletions NBitcoin/BIP174/PSBTCoin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,53 +111,65 @@ public virtual void AddKeyPath(PubKey pubKey, RootedKeyPath rootedKeyPath)
/// <param name="accountKey">The account key that will be used to sign (ie. 49'/0'/0')</param>
/// <param name="accountKeyPath">The account key path</param>
/// <returns>HD Keys matching master root key</returns>
public IEnumerable<PSBTHDKeyMatch> HDKeysFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
public IEnumerable<PSBTHDKeyMatch> HDKeysFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
{
if (accountKey == null)
throw new ArgumentNullException(nameof(accountKey));
if (accountHDScriptPubKey == null)
throw new ArgumentNullException(nameof(accountHDScriptPubKey));
return HDKeysFor(accountHDScriptPubKey, accountKey, accountKeyPath, accountKey.GetPublicKey().GetHDFingerPrint());
}
internal IEnumerable<PSBTHDKeyMatch> HDKeysFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath, HDFingerprint accountFingerprint)
{
var coinScriptPubKey = this.GetTxOut()?.ScriptPubKey;
bool Match(KeyValuePair<IPubKey, RootedKeyPath> hdKey, HDFingerprint? expectedMasterFp, IHDKey accountKey, KeyPath addressPath)
{
if (expectedMasterFp is not null &&
hdKey.Value.MasterFingerprint != expectedMasterFp.Value)
return false;
if (accountHDScriptPubKey is not null &&
accountHDScriptPubKey.Derive(addressPath).ScriptPubKey != coinScriptPubKey)
return false;
var derived = accountKey.Derive(addressPath);
return hdKey.Key switch
{
PubKey pk => derived.GetPublicKey().Equals(pk),
#if HAS_SPAN
TaprootPubKey tpk => derived.GetPublicKey().GetTaprootFullPubKey() is var tfp && (tfp.OutputKey.Equals(tpk) || tfp.InternalKey.AsTaprootPubKey().Equals(tpk)),
#endif
_ => false
};
}
accountKey = accountKey.AsHDKeyCache();
accountHDScriptPubKey = accountHDScriptPubKey?.AsHDKeyCache();
var coinScriptPubKey = this.GetTxOut()?.ScriptPubKey;
foreach (var hdKey in EnumerateKeyPaths())
{
bool matched = false;

var canDeriveHardenedPath = (accountKey.CanDeriveHardenedPath() && (accountHDScriptPubKey == null || accountHDScriptPubKey.CanDeriveHardenedPath()));
// The case where the fingerprint of the hdkey is exactly equal to the accountKey
if (hdKey.Value.MasterFingerprint == accountFingerprint)
if (!hdKey.Value.KeyPath.IsHardenedPath || canDeriveHardenedPath)
{
// The fingerprint match, but we need to check the public keys, because fingerprint collision is easy to provoke
if (!hdKey.Value.KeyPath.IsHardenedPath || (accountKey.CanDeriveHardenedPath() && (accountHDScriptPubKey == null || accountHDScriptPubKey.CanDeriveHardenedPath())))
if (Match(hdKey, accountFingerprint, accountKey, hdKey.Value.KeyPath))
{
if (accountHDScriptPubKey == null || accountHDScriptPubKey.Derive(hdKey.Value.KeyPath).ScriptPubKey == coinScriptPubKey)
{
yield return CreateHDKeyMatch(accountKey, hdKey.Value.KeyPath, hdKey);
matched = true;
}
yield return CreateHDKeyMatch(accountKey, hdKey.Value.KeyPath, hdKey);
matched = true;
}
}

// The typical case where accountkey is based on an hardened derivation (eg. 49'/0'/0')
if (!matched && accountKeyPath?.MasterFingerprint is HDFingerprint mp && hdKey.Value.MasterFingerprint == mp)
if (!matched && accountKeyPath?.MasterFingerprint is HDFingerprint mp)
{
var addressPath = hdKey.Value.KeyPath.GetAddressKeyPath();
// The cases where addresses are generated on a non-hardened path below it (eg. 49'/0'/0'/0/1)
if (addressPath.Indexes.Length != 0)
{
if (accountHDScriptPubKey == null || accountHDScriptPubKey.Derive(addressPath).ScriptPubKey == coinScriptPubKey)
if (Match(hdKey, mp, accountKey, addressPath))
{
yield return CreateHDKeyMatch(accountKey, addressPath, hdKey);
matched = true;
}
}
// in some cases addresses are generated on a hardened path below the account key (eg. 49'/0'/0'/0'/1') in which case we
// need to brute force what the address key path is
else if (accountKey.CanDeriveHardenedPath() && (accountHDScriptPubKey == null || accountHDScriptPubKey.CanDeriveHardenedPath())) // We can only do this if we can derive hardened paths
else if (canDeriveHardenedPath) // We can only do this if we can derive hardened paths
{
int addressPathSize = 0;
var hdKeyIndexes = hdKey.Value.KeyPath.Indexes;
Expand All @@ -166,10 +178,9 @@ internal IEnumerable<PSBTHDKeyMatch> HDKeysFor(IHDScriptPubKey? accountHDScriptP
var indexes = new uint[addressPathSize];
Array.Copy(hdKeyIndexes, hdKey.Value.KeyPath.Length - addressPathSize, indexes, 0, addressPathSize);
addressPath = new KeyPath(indexes);
if (accountKey.Derive(addressPath).GetPublicKey().Equals(hdKey.Key))
if (Match(hdKey, null, accountKey, addressPath))
{
if (accountHDScriptPubKey == null || accountHDScriptPubKey.Derive(addressPath).ScriptPubKey == coinScriptPubKey)
yield return CreateHDKeyMatch(accountKey, addressPath, hdKey);
yield return CreateHDKeyMatch(accountKey, addressPath, hdKey);
matched = true;
break;
}
Expand Down
7 changes: 3 additions & 4 deletions NBitcoin/BIP174/PSBTCoinList.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public class PSBTCoinList<T> : IReadOnlyList<T> where T : PSBTCoin
/// <param name="accountKey">The account key that will be used to sign (ie. 49'/0'/0')</param>
/// <param name="accountKeyPath">The account key path</param>
/// <returns>Inputs with HD keys matching masterFingerprint and account key</returns>
public IEnumerable<T> CoinsFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
public IEnumerable<T> CoinsFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
{
return GetPSBTCoins(accountHDScriptPubKey, accountKey, accountKeyPath);
}
Expand All @@ -119,7 +119,7 @@ public IEnumerable<T> CoinsFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey acc
/// <param name="accountKey">The account key that will be used to sign (ie. 49'/0'/0')</param>
/// <param name="accountKeyPath">The account key path</param>
/// <returns>HD Keys matching master root key</returns>
public IEnumerable<PSBTHDKeyMatch<T>> HDKeysFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
public IEnumerable<PSBTHDKeyMatch<T>> HDKeysFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
{
return GetHDKeys(accountHDScriptPubKey, accountKey, accountKeyPath);
}
Expand All @@ -135,7 +135,7 @@ public IEnumerable<PSBTHDKeyMatch<T>> HDKeysFor(IHDKey accountKey, RootedKeyPath
return GetHDKeys(null, accountKey, accountKeyPath);
}

internal IEnumerable<T> GetPSBTCoins(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
internal IEnumerable<T> GetPSBTCoins(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
{
return GetHDKeys(accountHDScriptPubKey, accountKey, accountKeyPath)
.Select(c => c.Coin)
Expand All @@ -147,7 +147,6 @@ internal IEnumerable<PSBTHDKeyMatch<T>> GetHDKeys(IHDScriptPubKey? hdScriptPubKe
if (accountKey == null)
throw new ArgumentNullException(nameof(accountKey));
accountKey = accountKey.AsHDKeyCache();
hdScriptPubKey = hdScriptPubKey?.AsHDKeyCache();
var accountFingerprint = accountKey.GetPublicKey().GetHDFingerPrint();
foreach (var c in this)
{
Expand Down
5 changes: 1 addition & 4 deletions NBitcoin/BIP174/PSBTInput.cs
Original file line number Diff line number Diff line change
Expand Up @@ -504,16 +504,13 @@ public void TrySign(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, Ro
{
TrySign(accountHDScriptPubKey, accountKey, accountKeyPath, null);
}
internal void TrySign(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath, SigningOptions? signingOptions)
internal void TrySign(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath, SigningOptions? signingOptions)
{
if (accountKey == null)
throw new ArgumentNullException(nameof(accountKey));
if (accountHDScriptPubKey == null)
throw new ArgumentNullException(nameof(accountHDScriptPubKey));
if (IsFinalized())
return;
var cache = accountKey.AsHDKeyCache();
accountHDScriptPubKey = accountHDScriptPubKey.AsHDKeyCache();
foreach (var hdk in this.HDKeysFor(accountHDScriptPubKey, cache, accountKeyPath))
{
if (((HDKeyCache)cache.Derive(hdk.AddressKeyPath)).Inner is ISecret k)
Expand Down
16 changes: 4 additions & 12 deletions NBitcoin/BIP174/PartiallySignedTransaction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -530,13 +530,11 @@ public PSBT SignAll(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey)
/// <param name="accountKey">The account key with which to sign</param>
/// <param name="accountKeyPath">The account key path (eg. [masterFP]/49'/0'/0')</param>
/// <returns>This PSBT</returns>
public PSBT SignAll(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath)
public PSBT SignAll(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath)
{
if (accountKey == null)
throw new ArgumentNullException(nameof(accountKey));
if (accountHDScriptPubKey == null)
throw new ArgumentNullException(nameof(accountHDScriptPubKey));
accountHDScriptPubKey = accountHDScriptPubKey.AsHDKeyCache();
accountHDScriptPubKey = accountHDScriptPubKey?.AsHDKeyCache();
accountKey = accountKey.AsHDKeyCache();
Money total = Money.Zero;

Expand Down Expand Up @@ -1100,13 +1098,10 @@ public Money GetBalance(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey
/// <param name="accountKey">The account key that will be used to sign (ie. 49'/0'/0')</param>
/// <param name="accountKeyPath">The account key path</param>
/// <returns>Inputs with HD keys matching masterFingerprint and account key</returns>
public IEnumerable<PSBTCoin> CoinsFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
public IEnumerable<PSBTCoin> CoinsFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
{
if (accountKey == null)
throw new ArgumentNullException(nameof(accountKey));
if (accountHDScriptPubKey == null)
throw new ArgumentNullException(nameof(accountHDScriptPubKey));
accountHDScriptPubKey = accountHDScriptPubKey.AsHDKeyCache();
accountKey = accountKey.AsHDKeyCache();
return Inputs.CoinsFor(accountHDScriptPubKey, accountKey, accountKeyPath).OfType<PSBTCoin>().Concat(Outputs.CoinsFor(accountHDScriptPubKey, accountKey, accountKeyPath).OfType<PSBTCoin>());
}
Expand All @@ -1119,13 +1114,10 @@ public IEnumerable<PSBTCoin> CoinsFor(IHDScriptPubKey accountHDScriptPubKey, IHD
/// <param name="accountKey">The account key that will be used to sign (ie. 49'/0'/0')</param>
/// <param name="accountKeyPath">The account key path</param>
/// <returns>HD Keys matching master root key</returns>
public IEnumerable<PSBTHDKeyMatch> HDKeysFor(IHDScriptPubKey accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
public IEnumerable<PSBTHDKeyMatch> HDKeysFor(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, RootedKeyPath? accountKeyPath = null)
{
if (accountKey == null)
throw new ArgumentNullException(nameof(accountKey));
if (accountHDScriptPubKey == null)
throw new ArgumentNullException(nameof(accountHDScriptPubKey));
accountHDScriptPubKey = accountHDScriptPubKey.AsHDKeyCache();
accountKey = accountKey.AsHDKeyCache();
return Inputs.HDKeysFor(accountHDScriptPubKey, accountKey, accountKeyPath).OfType<PSBTHDKeyMatch>().Concat(Outputs.HDKeysFor(accountHDScriptPubKey, accountKey, accountKeyPath));
}
Expand Down
2 changes: 1 addition & 1 deletion NBitcoin/Network.cs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ internal static string CreateBase58(Base58Type type, byte[] bytes, Network netwo
throw new ArgumentNullException(nameof(network));
if (bytes == null)
throw new ArgumentNullException(nameof(bytes));
var versionBytes = network.GetVersionBytes(type, true);
var versionBytes = network.GetVersionBytes(type, true)!;
return network.NetworkStringParser.GetBase58CheckEncoder().EncodeData(versionBytes.Concat(bytes));
}

Expand Down
2 changes: 1 addition & 1 deletion NBitcoin/Secp256k1/ECPrivKey.cs
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ public bool TrySignECDSA(ReadOnlySpan<byte> msg32, INonceFunction? nonceFunction
{
return TrySignECDSA(msg32, nonceFunction, out _, out signature);
}
public bool TrySignECDSA(ReadOnlySpan<byte> msg32, INonceFunction? nonceFunction, out int recid, out SecpECDSASignature? signature)
public bool TrySignECDSA(ReadOnlySpan<byte> msg32, INonceFunction? nonceFunction, out int recid, [MaybeNullWhen(false)] out SecpECDSASignature signature)
{
AssertNotDisposed();
recid = 0;
Expand Down
46 changes: 22 additions & 24 deletions NBitcoin/Utils.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using NBitcoin.DataEncoders;
#nullable enable
using NBitcoin.DataEncoders;
using System;
using System.Collections.Generic;
using System.IO;
Expand All @@ -15,6 +16,7 @@
using System.Runtime.InteropServices;
#if !NOSOCKET
using System.Net.Sockets;
using System.Diagnostics.CodeAnalysis;
#endif
#if WINDOWS_UWP
using System.Net.Sockets;
Expand All @@ -35,24 +37,24 @@ internal static Secp256k1.SecpECDSASignature Sign(this Secp256k1.ECPrivKey key,
{
Span<byte> hash = stackalloc byte[32];
h.ToBytes(hash);
byte[] extra_entropy = null;
Secp256k1.RFC6979NonceFunction nonceFunction = null;
byte[]? extra_entropy = null;
Secp256k1.RFC6979NonceFunction? nonceFunction = null;
Span<byte> vchSig = stackalloc byte[Secp256k1.SecpECDSASignature.MaxLength];
Secp256k1.SecpECDSASignature sig;
Secp256k1.SecpECDSASignature? sig;
uint counter = 0;
bool ret = key.TrySignECDSA(hash, null, out recid, out sig);
key.TrySignECDSA(hash, null, out recid, out sig);
// Grind for low R
while (ret && sig.r.IsHigh && enforceLowR)
while (sig is not null && sig.r.IsHigh && enforceLowR)
{
if (extra_entropy == null || nonceFunction == null)
{
extra_entropy = new byte[32];
nonceFunction = new Secp256k1.RFC6979NonceFunction(extra_entropy);
}
Utils.ToBytes(++counter, true, extra_entropy.AsSpan());
ret = key.TrySignECDSA(hash, nonceFunction, out recid, out sig);
key.TrySignECDSA(hash, nonceFunction, out recid, out sig);
}
return sig;
return sig!;
}
#endif
/// <summary>
Expand Down Expand Up @@ -174,7 +176,7 @@ public static T ToNetwork<T>(this T obj, ChainName chainName) where T : IBitcoin
{
if (obj == null)
throw new ArgumentNullException(nameof(obj));
if (chainName == null)
if (chainName is null)
throw new ArgumentNullException(nameof(chainName));
if (obj.Network.ChainName == chainName)
return obj;
Expand All @@ -183,7 +185,7 @@ public static T ToNetwork<T>(this T obj, ChainName chainName) where T : IBitcoin

public static T ToNetwork<T>(this T obj, Network network) where T : IBitcoinString
{
if (network == null)
if (network is null)
throw new ArgumentNullException(nameof(network));
if (obj == null)
throw new ArgumentNullException(nameof(obj));
Expand All @@ -195,7 +197,7 @@ public static T ToNetwork<T>(this T obj, Network network) where T : IBitcoinStri
if (b58.Type != Base58Type.COLORED_ADDRESS)
{

byte[] version = network.GetVersionBytes(b58.Type, true);
byte[] version = network.GetVersionBytes(b58.Type, true)!;
var enc = network.NetworkStringParser.GetBase58CheckEncoder();
var inner = enc.DecodeData(b58.ToString()).Skip(version.Length).ToArray();
var newBase58 = enc.EncodeData(version.Concat(inner).ToArray());
Expand All @@ -211,10 +213,10 @@ public static T ToNetwork<T>(this T obj, Network network) where T : IBitcoinStri
else if (obj is IBech32Data)
{
var b32 = (IBech32Data)obj;
var encoder = b32.Network.GetBech32Encoder(b32.Type, true);
var encoder = b32.Network.GetBech32Encoder(b32.Type, true)!;
byte wit;
var data = encoder.Decode(b32.ToString(), out wit);
encoder = network.GetBech32Encoder(b32.Type, true);
encoder = network.GetBech32Encoder(b32.Type, true)!;
var str = encoder.Encode(wit, data);
return (T)(object)Network.Parse<T>(str, network);
}
Expand Down Expand Up @@ -257,12 +259,12 @@ public static int ReadBytes(this Stream stream, int count, out byte[] result)
result = new byte[count];
return stream.Read(result, 0, count);
}
public static IEnumerable<T> Resize<T>(this List<T> list, int count)
public static IEnumerable<T?> Resize<T>(this List<T?> list, int count)
{
if (list.Count == count)
return new T[0];

List<T> removed = new List<T>();
var removed = new List<T?>();

for (int i = list.Count - 1; i + 1 > count; i--)
{
Expand Down Expand Up @@ -419,9 +421,9 @@ public static void AddOrReplace<TKey, TValue>(this IDictionary<TKey, TValue> dic
dico.Add(key, value);
}

public static TValue TryGet<TKey, TValue>(this IDictionary<TKey, TValue> dictionary, TKey key)
public static TValue? TryGet<TKey, TValue>(this IDictionary<TKey, TValue> dictionary, TKey key)
{
TValue value;
TValue? value;
dictionary.TryGetValue(key, out value);
return value;
}
Expand Down Expand Up @@ -568,10 +570,6 @@ private static void Write(MemoryStream ms, byte[] bytes)
#if !HAS_SPAN
internal static byte[] BigIntegerToBytes(BigInteger b, int numBytes)
{
if (b == null)
{
return null;
}
byte[] bytes = new byte[numBytes];
byte[] biBytes = b.ToByteArray();
int start = (biBytes.Length == numBytes + 1) ? 1 : 0;
Expand Down Expand Up @@ -653,7 +651,7 @@ public static DateTimeOffset UnixTimeToDateTime(long timestamp)

public static string ExceptionToString(Exception exception)
{
Exception ex = exception;
Exception? ex = exception;
StringBuilder stringBuilder = new StringBuilder(128);
while (ex != null)
{
Expand All @@ -670,7 +668,7 @@ public static string ExceptionToString(Exception exception)
return stringBuilder.ToString();
}

public static void Shuffle<T>(T[] arr, Random rand)
public static void Shuffle<T>(T[] arr, Random? rand)
{
rand = rand ?? new Random();
for (int i = 0; i < arr.Length; i++)
Expand Down Expand Up @@ -1000,7 +998,7 @@ public static ulong ToUInt64(byte[] value, bool littleEndian)

#if !NOSOCKET

public static bool TryParseEndpoint(string hostPort, int defaultPort, out EndPoint endpoint)
public static bool TryParseEndpoint(string hostPort, int defaultPort, [MaybeNullWhen(false)] out EndPoint endpoint)
{
if (hostPort == null)
throw new ArgumentNullException(nameof(hostPort));
Expand Down
Loading