Skip to content

Commit d637a9b

Browse files
committed
Taproot descriptor inference
1 parent c7388e5 commit d637a9b

File tree

4 files changed

+237
-9
lines changed

4 files changed

+237
-9
lines changed

src/script/descriptor.cpp

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ class ConstPubkeyProvider final : public PubkeyProvider
244244
bool m_xonly;
245245

246246
public:
247-
ConstPubkeyProvider(uint32_t exp_index, const CPubKey& pubkey, bool xonly = false) : PubkeyProvider(exp_index), m_pubkey(pubkey), m_xonly(xonly) {}
247+
ConstPubkeyProvider(uint32_t exp_index, const CPubKey& pubkey, bool xonly) : PubkeyProvider(exp_index), m_pubkey(pubkey), m_xonly(xonly) {}
248248
bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) override
249249
{
250250
key = m_pubkey;
@@ -931,7 +931,7 @@ std::unique_ptr<PubkeyProvider> ParsePubkeyInner(uint32_t key_exp_index, const S
931931
CPubKey pubkey(data);
932932
if (pubkey.IsFullyValid()) {
933933
if (permit_uncompressed || pubkey.IsCompressed()) {
934-
return std::make_unique<ConstPubkeyProvider>(key_exp_index, pubkey);
934+
return std::make_unique<ConstPubkeyProvider>(key_exp_index, pubkey, false);
935935
} else {
936936
error = "Uncompressed keys are not allowed";
937937
return nullptr;
@@ -952,7 +952,7 @@ std::unique_ptr<PubkeyProvider> ParsePubkeyInner(uint32_t key_exp_index, const S
952952
if (permit_uncompressed || key.IsCompressed()) {
953953
CPubKey pubkey = key.GetPubKey();
954954
out.keys.emplace(pubkey.GetID(), key);
955-
return std::make_unique<ConstPubkeyProvider>(key_exp_index, pubkey);
955+
return std::make_unique<ConstPubkeyProvider>(key_exp_index, pubkey, ctx == ParseScriptContext::P2TR);
956956
} else {
957957
error = "Uncompressed keys are not allowed";
958958
return nullptr;
@@ -1221,42 +1221,66 @@ std::unique_ptr<DescriptorImpl> ParseScript(uint32_t& key_exp_index, Span<const
12211221

12221222
std::unique_ptr<PubkeyProvider> InferPubkey(const CPubKey& pubkey, ParseScriptContext, const SigningProvider& provider)
12231223
{
1224-
std::unique_ptr<PubkeyProvider> key_provider = std::make_unique<ConstPubkeyProvider>(0, pubkey);
1224+
std::unique_ptr<PubkeyProvider> key_provider = std::make_unique<ConstPubkeyProvider>(0, pubkey, false);
12251225
KeyOriginInfo info;
12261226
if (provider.GetKeyOrigin(pubkey.GetID(), info)) {
12271227
return std::make_unique<OriginPubkeyProvider>(0, std::move(info), std::move(key_provider));
12281228
}
12291229
return key_provider;
12301230
}
12311231

1232+
std::unique_ptr<PubkeyProvider> InferXOnlyPubkey(const XOnlyPubKey& xkey, ParseScriptContext ctx, const SigningProvider& provider)
1233+
{
1234+
unsigned char full_key[CPubKey::COMPRESSED_SIZE] = {0x02};
1235+
std::copy(xkey.begin(), xkey.end(), full_key + 1);
1236+
CPubKey pubkey(full_key);
1237+
std::unique_ptr<PubkeyProvider> key_provider = std::make_unique<ConstPubkeyProvider>(0, pubkey, true);
1238+
KeyOriginInfo info;
1239+
if (provider.GetKeyOrigin(pubkey.GetID(), info)) {
1240+
return std::make_unique<OriginPubkeyProvider>(0, std::move(info), std::move(key_provider));
1241+
} else {
1242+
full_key[0] = 0x03;
1243+
pubkey = CPubKey(full_key);
1244+
if (provider.GetKeyOrigin(pubkey.GetID(), info)) {
1245+
return std::make_unique<OriginPubkeyProvider>(0, std::move(info), std::move(key_provider));
1246+
}
1247+
}
1248+
return key_provider;
1249+
}
1250+
12321251
std::unique_ptr<DescriptorImpl> InferScript(const CScript& script, ParseScriptContext ctx, const SigningProvider& provider)
12331252
{
1253+
if (ctx == ParseScriptContext::P2TR && script.size() == 34 && script[0] == 32 && script[33] == OP_CHECKSIG) {
1254+
XOnlyPubKey key{Span<const unsigned char>{script.data() + 1, script.data() + 33}};
1255+
return std::make_unique<PKDescriptor>(InferXOnlyPubkey(key, ctx, provider));
1256+
}
1257+
12341258
std::vector<std::vector<unsigned char>> data;
12351259
TxoutType txntype = Solver(script, data);
12361260

1237-
if (txntype == TxoutType::PUBKEY) {
1261+
if (txntype == TxoutType::PUBKEY && (ctx == ParseScriptContext::TOP || ctx == ParseScriptContext::P2SH || ctx == ParseScriptContext::P2WSH)) {
12381262
CPubKey pubkey(data[0]);
12391263
if (pubkey.IsValid()) {
12401264
return std::make_unique<PKDescriptor>(InferPubkey(pubkey, ctx, provider));
12411265
}
12421266
}
1243-
if (txntype == TxoutType::PUBKEYHASH) {
1267+
if (txntype == TxoutType::PUBKEYHASH && (ctx == ParseScriptContext::TOP || ctx == ParseScriptContext::P2SH || ctx == ParseScriptContext::P2WSH)) {
12441268
uint160 hash(data[0]);
12451269
CKeyID keyid(hash);
12461270
CPubKey pubkey;
12471271
if (provider.GetPubKey(keyid, pubkey)) {
12481272
return std::make_unique<PKHDescriptor>(InferPubkey(pubkey, ctx, provider));
12491273
}
12501274
}
1251-
if (txntype == TxoutType::WITNESS_V0_KEYHASH && ctx != ParseScriptContext::P2WSH) {
1275+
if (txntype == TxoutType::WITNESS_V0_KEYHASH && (ctx == ParseScriptContext::TOP || ctx == ParseScriptContext::P2SH)) {
12521276
uint160 hash(data[0]);
12531277
CKeyID keyid(hash);
12541278
CPubKey pubkey;
12551279
if (provider.GetPubKey(keyid, pubkey)) {
12561280
return std::make_unique<WPKHDescriptor>(InferPubkey(pubkey, ctx, provider));
12571281
}
12581282
}
1259-
if (txntype == TxoutType::MULTISIG) {
1283+
if (txntype == TxoutType::MULTISIG && (ctx == ParseScriptContext::TOP || ctx == ParseScriptContext::P2SH || ctx == ParseScriptContext::P2WSH)) {
12601284
std::vector<std::unique_ptr<PubkeyProvider>> providers;
12611285
for (size_t i = 1; i + 1 < data.size(); ++i) {
12621286
CPubKey pubkey(data[i]);
@@ -1273,7 +1297,7 @@ std::unique_ptr<DescriptorImpl> InferScript(const CScript& script, ParseScriptCo
12731297
if (sub) return std::make_unique<SHDescriptor>(std::move(sub));
12741298
}
12751299
}
1276-
if (txntype == TxoutType::WITNESS_V0_SCRIPTHASH && ctx != ParseScriptContext::P2WSH) {
1300+
if (txntype == TxoutType::WITNESS_V0_SCRIPTHASH && (ctx == ParseScriptContext::TOP || ctx == ParseScriptContext::P2SH)) {
12771301
CScriptID scriptid;
12781302
CRIPEMD160().Write(data[0].data(), data[0].size()).Finalize(scriptid.begin());
12791303
CScript subscript;
@@ -1282,6 +1306,40 @@ std::unique_ptr<DescriptorImpl> InferScript(const CScript& script, ParseScriptCo
12821306
if (sub) return std::make_unique<WSHDescriptor>(std::move(sub));
12831307
}
12841308
}
1309+
if (txntype == TxoutType::WITNESS_V1_TAPROOT && ctx == ParseScriptContext::TOP) {
1310+
// Extract x-only pubkey from output.
1311+
XOnlyPubKey pubkey;
1312+
std::copy(data[0].begin(), data[0].end(), pubkey.begin());
1313+
// Request spending data.
1314+
TaprootSpendData tap;
1315+
if (provider.GetTaprootSpendData(pubkey, tap)) {
1316+
// If found, convert it back to tree form.
1317+
auto tree = InferTaprootTree(tap, pubkey);
1318+
if (tree) {
1319+
// If that works, try to infer subdescriptors for all leaves.
1320+
bool ok = true;
1321+
std::vector<std::unique_ptr<DescriptorImpl>> subscripts; //!< list of script subexpressions
1322+
std::vector<int> depths; //!< depth in the tree of each subexpression (same length subscripts)
1323+
for (const auto& [depth, script, leaf_ver] : *tree) {
1324+
std::unique_ptr<DescriptorImpl> subdesc;
1325+
if (leaf_ver == TAPROOT_LEAF_TAPSCRIPT) {
1326+
subdesc = InferScript(script, ParseScriptContext::P2TR, provider);
1327+
}
1328+
if (!subdesc) {
1329+
ok = false;
1330+
break;
1331+
} else {
1332+
subscripts.push_back(std::move(subdesc));
1333+
depths.push_back(depth);
1334+
}
1335+
}
1336+
if (ok) {
1337+
auto key = InferXOnlyPubkey(tap.internal_key, ParseScriptContext::P2TR, provider);
1338+
return std::make_unique<TRDescriptor>(std::move(key), std::move(subscripts), std::move(depths));
1339+
}
1340+
}
1341+
}
1342+
}
12851343

12861344
CTxDestination dest;
12871345
if (ExtractDestination(script, dest)) {

src/script/standard.cpp

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,3 +520,138 @@ TaprootSpendData TaprootBuilder::GetSpendData() const
520520
}
521521
return spd;
522522
}
523+
524+
std::optional<std::vector<std::tuple<int, CScript, int>>> InferTaprootTree(const TaprootSpendData& spenddata, const XOnlyPubKey& output)
525+
{
526+
// Verify that the output matches the assumed Merkle root and internal key.
527+
auto tweak = spenddata.internal_key.CreateTapTweak(spenddata.merkle_root.IsNull() ? nullptr : &spenddata.merkle_root);
528+
if (!tweak || tweak->first != output) return std::nullopt;
529+
// If the Merkle root is 0, the tree is empty, and we're done.
530+
std::vector<std::tuple<int, CScript, int>> ret;
531+
if (spenddata.merkle_root.IsNull()) return ret;
532+
533+
/** Data structure to represent the nodes of the tree we're going to be build. */
534+
struct TreeNode {
535+
/** Hash of this none, if known; 0 otherwise. */
536+
uint256 hash;
537+
/** The left and right subtrees (note that their order is irrelevant). */
538+
std::unique_ptr<TreeNode> sub[2];
539+
/** If this is known to be a leaf node, a pointer to the (script, leaf_ver) pair.
540+
* nullptr otherwise. */
541+
const std::pair<CScript, int>* leaf = nullptr;
542+
/** Whether or not this node has been explored (is known to be a leaf, or known to have children). */
543+
bool explored = false;
544+
/** Whether or not this node is an inner node (unknown until explored = true). */
545+
bool inner;
546+
/** Whether or not we have produced output for this subtree. */
547+
bool done = false;
548+
};
549+
550+
// Build tree from the provides branches.
551+
TreeNode root;
552+
root.hash = spenddata.merkle_root;
553+
for (const auto& [key, control_blocks] : spenddata.scripts) {
554+
const auto& [script, leaf_ver] = key;
555+
for (const auto& control : control_blocks) {
556+
// Skip script records with nonsensical leaf version.
557+
if (leaf_ver < 0 || leaf_ver >= 0x100 || leaf_ver & 1) continue;
558+
// Skip script records with invalid control block sizes.
559+
if (control.size() < TAPROOT_CONTROL_BASE_SIZE || control.size() > TAPROOT_CONTROL_MAX_SIZE ||
560+
((control.size() - TAPROOT_CONTROL_BASE_SIZE) % TAPROOT_CONTROL_NODE_SIZE) != 0) continue;
561+
// Skip script records that don't match the control block.
562+
if ((control[0] & TAPROOT_LEAF_MASK) != leaf_ver) continue;
563+
// Skip script records that don't match the provided Merkle root.
564+
const uint256 leaf_hash = ComputeTapleafHash(leaf_ver, script);
565+
const uint256 merkle_root = ComputeTaprootMerkleRoot(control, leaf_hash);
566+
if (merkle_root != spenddata.merkle_root) continue;
567+
568+
TreeNode* node = &root;
569+
size_t levels = (control.size() - TAPROOT_CONTROL_BASE_SIZE) / TAPROOT_CONTROL_NODE_SIZE;
570+
for (size_t depth = 0; depth < levels; ++depth) {
571+
// Can't descend into a node which we already know is a leaf.
572+
if (node->explored && !node->inner) return std::nullopt;
573+
574+
// Extract partner hash from Merkle branch in control block.
575+
uint256 hash;
576+
std::copy(control.begin() + TAPROOT_CONTROL_BASE_SIZE + (levels - 1 - depth) * TAPROOT_CONTROL_NODE_SIZE,
577+
control.begin() + TAPROOT_CONTROL_BASE_SIZE + (levels - depth) * TAPROOT_CONTROL_NODE_SIZE,
578+
hash.begin());
579+
580+
if (node->sub[0]) {
581+
// Descend into the existing left or right branch.
582+
bool desc = false;
583+
for (int i = 0; i < 2; ++i) {
584+
if (node->sub[i]->hash == hash || (node->sub[i]->hash.IsNull() && node->sub[1-i]->hash != hash)) {
585+
node->sub[i]->hash = hash;
586+
node = &*node->sub[1-i];
587+
desc = true;
588+
break;
589+
}
590+
}
591+
if (!desc) return std::nullopt; // This probably requires a hash collision to hit.
592+
} else {
593+
// We're in an unexplored node. Create subtrees and descend.
594+
node->explored = true;
595+
node->inner = true;
596+
node->sub[0] = std::make_unique<TreeNode>();
597+
node->sub[1] = std::make_unique<TreeNode>();
598+
node->sub[1]->hash = hash;
599+
node = &*node->sub[0];
600+
}
601+
}
602+
// Cannot turn a known inner node into a leaf.
603+
if (node->sub[0]) return std::nullopt;
604+
node->explored = true;
605+
node->inner = false;
606+
node->leaf = &key;
607+
node->hash = leaf_hash;
608+
}
609+
}
610+
611+
// Recursive processing to turn the tree into flattened output. Use an explicit stack here to avoid
612+
// overflowing the call stack (the tree may be 128 levels deep).
613+
std::vector<TreeNode*> stack{&root};
614+
while (!stack.empty()) {
615+
TreeNode& node = *stack.back();
616+
if (!node.explored) {
617+
// Unexplored node, which means the tree is incomplete.
618+
return std::nullopt;
619+
} else if (!node.inner) {
620+
// Leaf node; produce output.
621+
ret.emplace_back(stack.size() - 1, node.leaf->first, node.leaf->second);
622+
node.done = true;
623+
stack.pop_back();
624+
} else if (node.sub[0]->done && !node.sub[1]->done && !node.sub[1]->explored && !node.sub[1]->hash.IsNull() &&
625+
(CHashWriter{HASHER_TAPBRANCH} << node.sub[1]->hash << node.sub[1]->hash).GetSHA256() == node.hash) {
626+
// Whenever there are nodes with two identical subtrees under it, we run into a problem:
627+
// the control blocks for the leaves underneath those will be identical as well, and thus
628+
// they will all be matched to the same path in the tree. The result is that at the location
629+
// where the duplicate occurred, the left child will contain a normal tree that can be explored
630+
// and processed, but the right one will remain unexplored.
631+
//
632+
// This situation can be detected, by encountering an inner node with unexplored right subtree
633+
// with known hash, and H_TapBranch(hash, hash) is equal to the parent node (this node)'s hash.
634+
//
635+
// To deal with this, simply process the left tree a second time (set its done flag to false;
636+
// noting that the done flag of its children have already been set to false after processing
637+
// those). To avoid ending up in an infinite loop, set the done flag of the right (unexplored)
638+
// subtree to true.
639+
node.sub[0]->done = false;
640+
node.sub[1]->done = true;
641+
} else if (node.sub[0]->done && node.sub[1]->done) {
642+
// An internal node which we're finished with.
643+
node.sub[0]->done = false;
644+
node.sub[1]->done = false;
645+
node.done = true;
646+
stack.pop_back();
647+
} else if (!node.sub[0]->done) {
648+
// An internal node whose left branch hasn't been processed yet. Do so first.
649+
stack.push_back(&*node.sub[0]);
650+
} else if (!node.sub[1]->done) {
651+
// An internal node whose right branch hasn't been processed yet. Do so first.
652+
stack.push_back(&*node.sub[1]);
653+
}
654+
}
655+
656+
return ret;
657+
}

src/script/standard.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,4 +327,12 @@ class TaprootBuilder
327327
TaprootSpendData GetSpendData() const;
328328
};
329329

330+
/** Given a TaprootSpendData and the output key, reconstruct its script tree.
331+
*
332+
* If the output doesn't match the spenddata, or if the data in spenddata is incomplete,
333+
* std::nullopt is returned. Otherwise, a vector of (depth, script, leaf_ver) tuples is
334+
* returned, corresponding to a depth-first traversal of the script tree.
335+
*/
336+
std::optional<std::vector<std::tuple<int, CScript, int>>> InferTaprootTree(const TaprootSpendData& spenddata, const XOnlyPubKey& output);
337+
330338
#endif // BITCOIN_SCRIPT_STANDARD_H

test/functional/wallet_taproot.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,12 @@ def do_test_addr(self, comment, pattern, privmap, treefn, keys):
230230
if treefn is not None:
231231
addr_r = self.make_addr(treefn, keys, i)
232232
assert_equal(addr_g, addr_r)
233+
desc_a = self.addr_gen.getaddressinfo(addr_g)['desc']
234+
if desc.startswith("tr("):
235+
assert desc_a.startswith("tr(")
236+
rederive = self.nodes[1].deriveaddresses(desc_a)
237+
assert_equal(len(rederive), 1)
238+
assert_equal(rederive[0], addr_g)
233239

234240
# tr descriptors cannot be imported when Taproot is not active
235241
result = self.privs_tr_enabled.importdescriptors([{"desc": desc, "timestamp": "now"}])
@@ -374,13 +380,34 @@ def run_test(self):
374380
None,
375381
2
376382
)
383+
self.do_test(
384+
"tr(XPRV,{XPUB,XPUB})",
385+
"tr($1/*,{pk($2/*),pk($2/*)})",
386+
[True, False],
387+
lambda k1, k2: (key(k1), [pk(k2), pk(k2)]),
388+
2
389+
)
390+
self.do_test(
391+
"tr(XPRV,{{XPUB,H},{H,XPUB}})",
392+
"tr($1/*,{{pk($2/*),pk($H)},{pk($H),pk($2/*)}})",
393+
[True, False],
394+
lambda k1, k2: (key(k1), [[pk(k2), pk(H_POINT)], [pk(H_POINT), pk(k2)]]),
395+
2
396+
)
377397
self.do_test(
378398
"tr(XPUB,{{H,{H,XPUB}},{H,{H,{H,XPRV}}}})",
379399
"tr($1/*,{{pk($H),{pk($H),pk($2/*)}},{pk($H),{pk($H),{pk($H),pk($3/*)}}}})",
380400
[False, False, True],
381401
lambda k1, k2, k3: (key(k1), [[pk(H_POINT), [pk(H_POINT), pk(k2)]], [pk(H_POINT), [pk(H_POINT), [pk(H_POINT), pk(k3)]]]]),
382402
3
383403
)
404+
self.do_test(
405+
"tr(XPRV,{XPUB,{{XPUB,{H,H}},{{H,H},XPUB}}})",
406+
"tr($1/*,{pk($2/*),{{pk($2/*),{pk($H),pk($H)}},{{pk($H),pk($H)},pk($2/*)}}})",
407+
[True, False],
408+
lambda k1, k2: (key(k1), [pk(k2), [[pk(k2), [pk(H_POINT), pk(H_POINT)]], [[pk(H_POINT), pk(H_POINT)], pk(k2)]]]),
409+
2
410+
)
384411

385412
self.log.info("Sending everything back...")
386413

0 commit comments

Comments
 (0)