Skip to content

Commit 25a3b9b

Browse files
committed
descriptors: Have GetPubKey fill origins directly
Instead of having ExpandHelper fill in the origins in the FlatSigningProvider output, have GetPubKey do it by itself. This reduces the extra variables needed in order to track and set origins in ExpandHelper. Also changes GetPubKey to return a std::optional<CPubKey> rather than using a bool and output parameters.
1 parent 6268bde commit 25a3b9b

File tree

1 file changed

+33
-39
lines changed

1 file changed

+33
-39
lines changed

src/script/descriptor.cpp

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -174,22 +174,20 @@ struct PubkeyProvider
174174
* Used by the Miniscript descriptors to check for duplicate keys in the script.
175175
*/
176176
bool operator<(PubkeyProvider& other) const {
177-
CPubKey a, b;
178-
SigningProvider dummy;
179-
KeyOriginInfo dummy_info;
177+
FlatSigningProvider dummy;
180178

181-
GetPubKey(0, dummy, a, dummy_info);
182-
other.GetPubKey(0, dummy, b, dummy_info);
179+
std::optional<CPubKey> a = GetPubKey(0, dummy, dummy);
180+
std::optional<CPubKey> b = other.GetPubKey(0, dummy, dummy);
183181

184182
return a < b;
185183
}
186184

187-
/** Derive a public key.
185+
/** Derive a public key and put it into out.
188186
* read_cache is the cache to read keys from (if not nullptr)
189187
* write_cache is the cache to write keys to (if not nullptr)
190188
* Caches are not exclusive but this is not tested. Currently we use them exclusively
191189
*/
192-
virtual bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const = 0;
190+
virtual std::optional<CPubKey> GetPubKey(int pos, const SigningProvider& arg, FlatSigningProvider& out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const = 0;
193191

194192
/** Whether this represent multiple public keys at different positions. */
195193
virtual bool IsRange() const = 0;
@@ -240,12 +238,15 @@ class OriginPubkeyProvider final : public PubkeyProvider
240238

241239
public:
242240
OriginPubkeyProvider(uint32_t exp_index, KeyOriginInfo info, std::unique_ptr<PubkeyProvider> provider, bool apostrophe) : PubkeyProvider(exp_index), m_origin(std::move(info)), m_provider(std::move(provider)), m_apostrophe(apostrophe) {}
243-
bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override
241+
std::optional<CPubKey> GetPubKey(int pos, const SigningProvider& arg, FlatSigningProvider& out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override
244242
{
245-
if (!m_provider->GetPubKey(pos, arg, key, info, read_cache, write_cache)) return false;
246-
std::copy(std::begin(m_origin.fingerprint), std::end(m_origin.fingerprint), info.fingerprint);
247-
info.path.insert(info.path.begin(), m_origin.path.begin(), m_origin.path.end());
248-
return true;
243+
std::optional<CPubKey> pub = m_provider->GetPubKey(pos, arg, out, read_cache, write_cache);
244+
if (!pub) return std::nullopt;
245+
auto& [pubkey, suborigin] = out.origins[pub->GetID()];
246+
Assert(pubkey == *pub); // m_provider must have a valid origin by this point.
247+
std::copy(std::begin(m_origin.fingerprint), std::end(m_origin.fingerprint), suborigin.fingerprint);
248+
suborigin.path.insert(suborigin.path.begin(), m_origin.path.begin(), m_origin.path.end());
249+
return pub;
249250
}
250251
bool IsRange() const override { return m_provider->IsRange(); }
251252
size_t GetSize() const override { return m_provider->GetSize(); }
@@ -298,13 +299,13 @@ class ConstPubkeyProvider final : public PubkeyProvider
298299

299300
public:
300301
ConstPubkeyProvider(uint32_t exp_index, const CPubKey& pubkey, bool xonly) : PubkeyProvider(exp_index), m_pubkey(pubkey), m_xonly(xonly) {}
301-
bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override
302+
std::optional<CPubKey> GetPubKey(int pos, const SigningProvider&, FlatSigningProvider& out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override
302303
{
303-
key = m_pubkey;
304-
info.path.clear();
304+
KeyOriginInfo info;
305305
CKeyID keyid = m_pubkey.GetID();
306306
std::copy(keyid.begin(), keyid.begin() + sizeof(info.fingerprint), info.fingerprint);
307-
return true;
307+
out.origins.emplace(keyid, std::make_pair(m_pubkey, info));
308+
return m_pubkey;
308309
}
309310
bool IsRange() const override { return false; }
310311
size_t GetSize() const override { return m_pubkey.size(); }
@@ -394,7 +395,7 @@ class BIP32PubkeyProvider final : public PubkeyProvider
394395
BIP32PubkeyProvider(uint32_t exp_index, const CExtPubKey& extkey, KeyPath path, DeriveType derive, bool apostrophe) : PubkeyProvider(exp_index), m_root_extkey(extkey), m_path(std::move(path)), m_derive(derive), m_apostrophe(apostrophe) {}
395396
bool IsRange() const override { return m_derive != DeriveType::NO; }
396397
size_t GetSize() const override { return 33; }
397-
bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key_out, KeyOriginInfo& final_info_out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override
398+
std::optional<CPubKey> GetPubKey(int pos, const SigningProvider& arg, FlatSigningProvider& out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override
398399
{
399400
KeyOriginInfo info;
400401
CKeyID keyid = m_root_extkey.pubkey.GetID();
@@ -410,16 +411,16 @@ class BIP32PubkeyProvider final : public PubkeyProvider
410411
bool der = true;
411412
if (read_cache) {
412413
if (!read_cache->GetCachedDerivedExtPubKey(m_expr_index, pos, final_extkey)) {
413-
if (m_derive == DeriveType::HARDENED) return false;
414+
if (m_derive == DeriveType::HARDENED) return std::nullopt;
414415
// Try to get the derivation parent
415-
if (!read_cache->GetCachedParentExtPubKey(m_expr_index, parent_extkey)) return false;
416+
if (!read_cache->GetCachedParentExtPubKey(m_expr_index, parent_extkey)) return std::nullopt;
416417
final_extkey = parent_extkey;
417418
if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos);
418419
}
419420
} else if (IsHardened()) {
420421
CExtKey xprv;
421422
CExtKey lh_xprv;
422-
if (!GetDerivedExtKey(arg, xprv, lh_xprv)) return false;
423+
if (!GetDerivedExtKey(arg, xprv, lh_xprv)) return std::nullopt;
423424
parent_extkey = xprv.Neuter();
424425
if (m_derive == DeriveType::UNHARDENED) der = xprv.Derive(xprv, pos);
425426
if (m_derive == DeriveType::HARDENED) der = xprv.Derive(xprv, pos | 0x80000000UL);
@@ -429,16 +430,15 @@ class BIP32PubkeyProvider final : public PubkeyProvider
429430
}
430431
} else {
431432
for (auto entry : m_path) {
432-
if (!parent_extkey.Derive(parent_extkey, entry)) return false;
433+
if (!parent_extkey.Derive(parent_extkey, entry)) return std::nullopt;
433434
}
434435
final_extkey = parent_extkey;
435436
if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos);
436437
assert(m_derive != DeriveType::HARDENED);
437438
}
438-
if (!der) return false;
439+
if (!der) return std::nullopt;
439440

440-
final_info_out = info;
441-
key_out = final_extkey.pubkey;
441+
out.origins.emplace(final_extkey.pubkey.GetID(), std::make_pair(final_extkey.pubkey, info));
442442

443443
if (write_cache) {
444444
// Only cache parent if there is any unhardened derivation
@@ -448,12 +448,12 @@ class BIP32PubkeyProvider final : public PubkeyProvider
448448
if (last_hardened_extkey.pubkey.IsValid()) {
449449
write_cache->CacheLastHardenedExtPubKey(m_expr_index, last_hardened_extkey);
450450
}
451-
} else if (final_info_out.path.size() > 0) {
451+
} else if (info.path.size() > 0) {
452452
write_cache->CacheDerivedExtPubKey(m_expr_index, pos, final_extkey);
453453
}
454454
}
455455

456-
return true;
456+
return final_extkey.pubkey;
457457
}
458458
std::string ToString(StringType type, bool normalized) const
459459
{
@@ -696,16 +696,17 @@ class DescriptorImpl : public Descriptor
696696
// NOLINTNEXTLINE(misc-no-recursion)
697697
bool ExpandHelper(int pos, const SigningProvider& arg, const DescriptorCache* read_cache, std::vector<CScript>& output_scripts, FlatSigningProvider& out, DescriptorCache* write_cache) const
698698
{
699-
std::vector<std::pair<CPubKey, KeyOriginInfo>> entries;
700-
entries.reserve(m_pubkey_args.size());
699+
FlatSigningProvider subprovider;
700+
std::vector<CPubKey> pubkeys;
701+
pubkeys.reserve(m_pubkey_args.size());
701702

702-
// Construct temporary data in `entries`, `subscripts`, and `subprovider` to avoid producing output in case of failure.
703+
// Construct temporary data in `pubkeys`, `subscripts`, and `subprovider` to avoid producing output in case of failure.
703704
for (const auto& p : m_pubkey_args) {
704-
entries.emplace_back();
705-
if (!p->GetPubKey(pos, arg, entries.back().first, entries.back().second, read_cache, write_cache)) return false;
705+
std::optional<CPubKey> pubkey = p->GetPubKey(pos, arg, subprovider, read_cache, write_cache);
706+
if (!pubkey) return false;
707+
pubkeys.push_back(pubkey.value());
706708
}
707709
std::vector<CScript> subscripts;
708-
FlatSigningProvider subprovider;
709710
for (const auto& subarg : m_subdescriptor_args) {
710711
std::vector<CScript> outscripts;
711712
if (!subarg->ExpandHelper(pos, arg, read_cache, outscripts, subprovider, write_cache)) return false;
@@ -714,13 +715,6 @@ class DescriptorImpl : public Descriptor
714715
}
715716
out.Merge(std::move(subprovider));
716717

717-
std::vector<CPubKey> pubkeys;
718-
pubkeys.reserve(entries.size());
719-
for (auto& entry : entries) {
720-
pubkeys.push_back(entry.first);
721-
out.origins.emplace(entry.first.GetID(), std::make_pair<CPubKey, KeyOriginInfo>(CPubKey(entry.first), std::move(entry.second)));
722-
}
723-
724718
output_scripts = MakeScripts(pubkeys, std::span{subscripts}, out);
725719
return true;
726720
}

0 commit comments

Comments
 (0)