@@ -17,6 +17,8 @@ use std::{
     sync::Arc,
 };
 
+#[cfg(not(feature = "embedded-db"))]
+use anyhow::Context;
 use ark_serialize::CanonicalDeserialize;
 use async_trait::async_trait;
 use futures::stream::TryStreamExt;
@@ -341,6 +343,7 @@ impl<Mode: TransactionMode> Transaction<Mode> {
 }
 
 // TODO: create a generic upsert function with retries that returns the column
+#[cfg(feature = "embedded-db")]
 pub(crate) fn build_hash_batch_insert(
     hashes: &[Vec<u8>],
 ) -> QueryResult<(QueryBuilder<'_>, String)> {
@@ -357,6 +360,165 @@ pub(crate) fn build_hash_batch_insert(
     Ok((query, sql))
 }
 
+/// Batch insert hashes using UNNEST for large batches (postgres only).
+/// Returns a map from hash bytes to their database IDs.
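+///
+/// Illustrative sketch (not from the test suite); assumes an open write
+/// transaction `tx`:
+/// ```ignore
+/// let hashes = vec![vec![0u8; 32], vec![1u8; 32]];
+/// let ids = batch_insert_hashes(hashes, &mut tx).await?;
+/// assert_eq!(ids.len(), 2);
+/// ```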
+#[cfg(not(feature = "embedded-db"))]
+pub(crate) async fn batch_insert_hashes(
+    hashes: Vec<Vec<u8>>,
+    tx: &mut Transaction<Write>,
+) -> QueryResult<HashMap<Vec<u8>, i32>> {
+    if hashes.is_empty() {
+        return Ok(HashMap::new());
+    }
+
+    // Use UNNEST-based batch insert (more efficient and avoids parameter limits)
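+    // The no-op DO UPDATE (instead of DO NOTHING) makes RETURNING yield ids for
+    // rows that already existed, not just newly inserted ones.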
+    let sql = "INSERT INTO hash(value) SELECT * FROM UNNEST($1::bytea[]) ON CONFLICT (value) DO \
+               UPDATE SET value = EXCLUDED.value RETURNING value, id";
+
+    let result: HashMap<Vec<u8>, i32> = sqlx::query_as(sql)
+        .bind(&hashes)
+        .fetch(tx.as_mut())
+        .try_collect()
+        .await
+        .map_err(|e| QueryError::Error {
+            message: format!("batch hash insert failed: {e}"),
+        })?;
+
+    Ok(result)
+}
+
+/// Type alias for a merkle proof with its traversal path.
+pub(crate) type ProofWithPath<Entry, Key, T, const ARITY: usize> =
+    (MerkleProof<Entry, Key, T, ARITY>, Vec<usize>);
+
+/// Collects nodes and hashes from merkle proofs.
+/// Returns `(nodes, hashes)` for batch insertion.
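+///
+/// Illustrative sketch (hypothetical caller), pairing with
+/// `batch_insert_hashes` above on the postgres path:
+/// ```ignore
+/// let (nodes, hashes) = collect_nodes_from_proofs(&proofs)?;
+/// let hash_ids = batch_insert_hashes(hashes.into_iter().collect(), &mut tx).await?;
+/// ```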
+pub(crate) fn collect_nodes_from_proofs<Entry, Key, T, const ARITY: usize>(
+    proofs: &[ProofWithPath<Entry, Key, T, ARITY>],
+) -> QueryResult<(Vec<NodeWithHashes>, HashSet<Vec<u8>>)>
+where
+    Entry: jf_merkle_tree_compat::Element + serde::Serialize,
+    Key: jf_merkle_tree_compat::Index + serde::Serialize,
+    T: jf_merkle_tree_compat::NodeValue,
+{
+    let mut nodes = Vec::new();
+    let mut hashes = HashSet::new();
+
+    for (proof, traversal_path) in proofs {
+        let pos = &proof.pos;
+        let path = &proof.proof;
+        let mut trav_path = traversal_path.iter().map(|n| *n as i32);
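+        // The first proof node (the leaf) uses the full traversal path; each
+        // successive ancestor drops one step via `trav_path.next()` at the end
+        // of the loop body, so shallower nodes get correspondingly shorter paths.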
+
+        for node in path.iter() {
+            match node {
+                MerkleNode::Empty => {
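+                    // An empty node is recorded with an all-zero 32-byte hash.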
+                    let index =
+                        serde_json::to_value(pos.clone()).map_err(|e| QueryError::Error {
+                            message: format!("malformed merkle position: {e}"),
+                        })?;
+                    let node_path: Vec<i32> = trav_path.clone().rev().collect();
+                    nodes.push((
+                        Node {
+                            path: node_path.into(),
+                            idx: Some(index),
+                            ..Default::default()
+                        },
+                        None,
+                        [0_u8; 32].to_vec(),
+                    ));
+                    hashes.insert([0_u8; 32].to_vec());
+                },
+                MerkleNode::ForgettenSubtree { .. } => {
+                    return Err(QueryError::Error {
+                        message: "Node in the Merkle path contains a forgotten subtree".into(),
+                    });
+                },
+                MerkleNode::Leaf { value, pos, elem } => {
+                    let mut leaf_commit = Vec::new();
+                    value.serialize_compressed(&mut leaf_commit).map_err(|e| {
+                        QueryError::Error {
+                            message: format!("malformed merkle leaf commitment: {e}"),
+                        }
+                    })?;
+
+                    let node_path: Vec<i32> = trav_path.clone().rev().collect();
+
+                    let index =
+                        serde_json::to_value(pos.clone()).map_err(|e| QueryError::Error {
+                            message: format!("malformed merkle position: {e}"),
+                        })?;
+                    let entry = serde_json::to_value(elem).map_err(|e| QueryError::Error {
+                        message: format!("malformed merkle element: {e}"),
+                    })?;
+
+                    nodes.push((
+                        Node {
+                            path: node_path.into(),
+                            idx: Some(index),
+                            entry: Some(entry),
+                            ..Default::default()
+                        },
+                        None,
+                        leaf_commit.clone(),
+                    ));
+
+                    hashes.insert(leaf_commit);
+                },
+                MerkleNode::Branch { value, children } => {
+                    let mut branch_hash = Vec::new();
+                    value.serialize_compressed(&mut branch_hash).map_err(|e| {
+                        QueryError::Error {
+                            message: format!("malformed merkle branch hash: {e}"),
+                        }
+                    })?;
+
+                    let mut children_bitvec = BitVec::new();
+                    let mut children_values = Vec::new();
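+                    // One bit per child slot: `true` if the child's hash is
+                    // recorded in `children_values`, `false` for an empty slot.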
+                    for child in children {
+                        let child = child.as_ref();
+                        match child {
+                            MerkleNode::Empty => {
+                                children_bitvec.push(false);
+                            },
+                            MerkleNode::Branch { value, .. }
+                            | MerkleNode::Leaf { value, .. }
+                            | MerkleNode::ForgettenSubtree { value } => {
+                                let mut hash = Vec::new();
+                                value.serialize_compressed(&mut hash).map_err(|e| {
+                                    QueryError::Error {
+                                        message: format!("malformed merkle node hash: {e}"),
+                                    }
+                                })?;
+
+                                children_values.push(hash);
+                                children_bitvec.push(true);
+                            },
+                        }
+                    }
+
+                    let node_path: Vec<i32> = trav_path.clone().rev().collect();
+                    nodes.push((
+                        Node {
+                            path: node_path.into(),
+                            children: None,
+                            children_bitvec: Some(children_bitvec),
+                            ..Default::default()
+                        },
+                        Some(children_values.clone()),
+                        branch_hash.clone(),
+                    ));
+                    hashes.insert(branch_hash);
+                    hashes.extend(children_values);
+                },
+            }
+
+            trav_path.next();
+        }
+    }
+
+    Ok((nodes, hashes))
+}
+
 // Represents a row in a state table
 #[derive(Debug, Default, Clone)]
 pub(crate) struct Node {
@@ -369,6 +531,10 @@ pub(crate) struct Node {
     pub(crate) entry: Option<JsonValue>,
 }
 
+/// Type alias for node data with optional children hashes and node hash.
+/// Used during batch collection before database insertion.
+pub(crate) type NodeWithHashes = (Node, Option<Vec<Vec<u8>>>, Vec<u8>);
+
 #[cfg(feature = "embedded-db")]
 impl From<sqlx::sqlite::SqliteRow> for Node {
     fn from(row: sqlx::sqlite::SqliteRow) -> Self {
@@ -409,40 +575,116 @@ impl Node {
         nodes: impl IntoIterator<Item = Self>,
         tx: &mut Transaction<Write>,
     ) -> anyhow::Result<()> {
-        tx.upsert(
-            name,
-            [
-                "path",
-                "created",
-                "hash_id",
-                "children",
-                "children_bitvec",
-                "idx",
-                "entry",
-            ],
-            ["path", "created"],
-            nodes.into_iter().map(|n| {
-                #[cfg(feature = "embedded-db")]
-                let children_bitvec: Option<String> = n
-                    .children_bitvec
-                    .clone()
-                    .map(|b| b.iter().map(|bit| if bit { '1' } else { '0' }).collect());
-
-                #[cfg(not(feature = "embedded-db"))]
-                let children_bitvec = n.children_bitvec.clone();
-
-                (
-                    n.path.clone(),
-                    n.created,
-                    n.hash_id,
-                    n.children.clone(),
-                    children_bitvec,
-                    n.idx.clone(),
-                    n.entry.clone(),
+        let nodes: Vec<_> = nodes.into_iter().collect();
+
+        // Use UNNEST-based batch insert for postgres (more efficient and avoids parameter limits)
+        #[cfg(not(feature = "embedded-db"))]
+        return Self::upsert_batch_unnest(name, nodes, tx).await;
+
+        #[cfg(feature = "embedded-db")]
+        {
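+            // Chunk size: 20 rows x 7 columns = 140 bound parameters per
+            // statement, comfortably below SQLite's per-statement limit
+            // (historically 999).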
+            for node_chunk in nodes.chunks(20) {
+                let rows: Vec<_> = node_chunk
+                    .iter()
+                    .map(|n| {
+                        let children_bitvec: Option<String> = n
+                            .children_bitvec
+                            .clone()
+                            .map(|b| b.iter().map(|bit| if bit { '1' } else { '0' }).collect());
+
+                        (
+                            n.path.clone(),
+                            n.created,
+                            n.hash_id,
+                            n.children.clone(),
+                            children_bitvec,
+                            n.idx.clone(),
+                            n.entry.clone(),
+                        )
+                    })
+                    .collect();
+
+                tx.upsert(
+                    name,
+                    [
+                        "path",
+                        "created",
+                        "hash_id",
+                        "children",
+                        "children_bitvec",
+                        "idx",
+                        "entry",
+                    ],
+                    ["path", "created"],
+                    rows,
                 )
-            })
-        )
-        .await
+                .await?;
+            }
+            Ok(())
+        }
+    }
+
+    #[cfg(not(feature = "embedded-db"))]
+    async fn upsert_batch_unnest(
+        name: &str,
+        nodes: Vec<Self>,
+        tx: &mut Transaction<Write>,
+    ) -> anyhow::Result<()> {
+        if nodes.is_empty() {
+            return Ok(());
+        }
+
+        // Deduplicate nodes by (path, created) - keep the last occurrence.
+        // This is required because UNNEST + ON CONFLICT cannot handle duplicates in the same batch
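+        // (Postgres would otherwise reject the statement with "ON CONFLICT DO
+        // UPDATE command cannot affect row a second time").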
+        let mut deduped = HashMap::new();
+        for node in nodes {
+            deduped.insert((node.path.to_string(), node.created), node);
+        }
+
+        let mut paths = Vec::with_capacity(deduped.len());
+        let mut createds = Vec::with_capacity(deduped.len());
+        let mut hash_ids = Vec::with_capacity(deduped.len());
+        let mut childrens = Vec::with_capacity(deduped.len());
+        let mut children_bitvecs = Vec::with_capacity(deduped.len());
+        let mut idxs = Vec::with_capacity(deduped.len());
+        let mut entries = Vec::with_capacity(deduped.len());
+
+        for node in deduped.into_values() {
+            paths.push(node.path);
+            createds.push(node.created);
+            hash_ids.push(node.hash_id);
+            childrens.push(node.children);
+            children_bitvecs.push(node.children_bitvec);
+            idxs.push(node.idx);
+            entries.push(node.entry);
+        }
+
+        let sql = format!(
+            r#"
+            INSERT INTO "{name}" (path, created, hash_id, children, children_bitvec, idx, entry)
+            SELECT * FROM UNNEST($1::jsonb[], $2::bigint[], $3::int[], $4::jsonb[], $5::bit varying[], $6::jsonb[], $7::jsonb[])
+            ON CONFLICT (path, created) DO UPDATE SET
+                hash_id = EXCLUDED.hash_id,
+                children = EXCLUDED.children,
+                children_bitvec = EXCLUDED.children_bitvec,
+                idx = EXCLUDED.idx,
+                entry = EXCLUDED.entry
+            "#
+        );
+
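+        // Bind order must match the UNNEST parameter positions $1..$7 above.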
+        sqlx::query(&sql)
+            .bind(&paths)
+            .bind(&createds)
+            .bind(&hash_ids)
+            .bind(&childrens)
+            .bind(&children_bitvecs)
+            .bind(&idxs)
+            .bind(&entries)
+            .execute(tx.as_mut())
+            .await
+            .context("batch upsert with UNNEST failed")?;
+
+        Ok(())
     }
 }
 