@@ -12,6 +12,7 @@ fn on_load() {
1212}
1313
1414
15+ use rust_core:: RustPDAG ;
1516
1617#[ extendr]
1718#[ derive( Debug , Clone ) ]
@@ -416,10 +417,233 @@ impl RIndependencies {
416417 }
417418}
418419
420+ #[ extendr]
421+ #[ derive( Debug , Clone ) ]
422+ pub struct PDAG {
423+ inner : RustPDAG ,
424+ }
425+
426+
427+ #[ extendr]
428+ impl PDAG {
429+ /// Create a new PDAG
430+ /// @export
431+ fn new ( ) -> Self {
432+ PDAG { inner : RustPDAG :: new ( ) }
433+ }
434+
435+ /// Add a single node
436+ /// @param node Node name
437+ /// @param latent Whether latent (default FALSE)
438+ /// @export
439+ fn add_node ( & mut self , node : String , latent : Option < bool > ) -> extendr_api:: Result < ( ) > {
440+ self . inner . add_node ( node, latent. unwrap_or ( false ) )
441+ . map_err ( |e| Error :: Other ( e. to_string ( ) ) )
442+ }
443+
444+ /// Add nodes from vector with optional latent mask (NULL means all false)
445+ /// @param nodes character vector
446+ /// @param latent NULL or logical vector
447+ /// @export
448+ fn add_nodes_from ( & mut self , nodes : Strings , latent : Nullable < Logicals > ) -> extendr_api:: Result < ( ) > {
449+ let node_vec: Vec < String > = nodes. iter ( ) . map ( |s| s. to_string ( ) ) . collect ( ) ;
450+ let latent_opt: Option < Vec < bool > > = latent. into_option ( ) . map ( |v| v. iter ( ) . map ( |x| x. is_true ( ) ) . collect ( ) ) ;
451+ self . inner . add_nodes_from ( node_vec, latent_opt) . map_err ( |e| Error :: Other ( e. to_string ( ) ) )
452+ }
453+
454+ /// Add single edge (directed or undirected)
455+ /// @param u source
456+ /// @param v target
457+ /// @param weight optional numeric (NULL)
458+ /// @param directed bool (TRUE: directed, FALSE: undirected)
459+ /// @export
460+ fn add_edge ( & mut self , u : String , v : String , weight : Nullable < f64 > , directed : Option < bool > ) -> extendr_api:: Result < ( ) > {
461+ let w = weight. into_option ( ) ;
462+ let d = directed. unwrap_or ( true ) ;
463+ self . inner . add_edge ( u, v, w, d) . map_err ( |e| Error :: Other ( e. to_string ( ) ) )
464+ }
465+
466+ /// Add multiple edges from an R list of pairs: list(c("A","B"), c("C","D"))
467+ /// @param ebunch list of character vectors length 2
468+ /// @param weights NULL or numeric vector
469+ /// @param directed bool
470+ /// @export
471+ fn add_edges_from ( & mut self , ebunch : List , weights : Nullable < Doubles > , directed : Option < bool > ) -> extendr_api:: Result < ( ) > {
472+ // convert ebunch (List) -> Vec<(String,String)>
473+ let mut edges: Vec < ( String , String ) > = Vec :: with_capacity ( ebunch. len ( ) ) ;
474+ for ( i, item) in ebunch. values ( ) . enumerate ( ) {
475+ // Each item must be a character vector of length 2
476+ let pair: Strings = item. try_into ( ) . map_err ( |_| Error :: Other ( format ! ( "ebunch[{}] must be a character vector of length 2" , i) ) ) ?;
477+ if pair. len ( ) != 2 {
478+ return Err ( Error :: Other ( format ! ( "ebunch[{}] must have exactly 2 elements" , i) ) ) ;
479+ }
480+ edges. push ( ( pair[ 0 ] . to_string ( ) , pair[ 1 ] . to_string ( ) ) ) ;
481+ }
482+ let weight_opt: Option < Vec < f64 > > = weights. into_option ( ) . map ( |v| v. iter ( ) . map ( |d| d. inner ( ) ) . collect ( ) ) ;
483+ let directed = directed. unwrap_or ( true ) ;
484+ self . inner . add_edges_from ( Some ( edges) , weight_opt, directed) . map_err ( |e| Error :: Other ( e. to_string ( ) ) )
485+ }
486+
487+ /// Return all edges. For PDAG this includes both directed and undirected (both directions placed into graph).
488+ /// Return as list(from = ..., to = ...) same as RDAG$edges()
489+ /// @export
490+ fn edges ( & self ) -> List {
491+ let edges = self . inner . edges ( ) ;
492+ let ( from, to) : ( Vec < _ > , Vec < _ > ) = edges. into_iter ( ) . unzip ( ) ;
493+ list ! ( from = from, to = to)
494+ }
495+
496+ /// Return nodes
497+ /// @export
498+ fn nodes ( & self ) -> Strings {
499+ self . inner . nodes ( ) . iter ( ) . map ( |s| s. as_str ( ) ) . collect :: < Strings > ( )
500+ }
501+
502+ /// Number of nodes
503+ /// @export
504+ fn node_count ( & self ) -> i32 {
505+ self . inner . node_map . len ( ) as i32
506+ }
507+
508+ /// Number of edges (count unique graph edges)
509+ /// @export
510+ fn edge_count ( & self ) -> i32 {
511+ self . inner . edges ( ) . len ( ) as i32
512+ }
513+
514+ /// Latent nodes
515+ /// @export
516+ fn latents ( & self ) -> Strings {
517+ let mut v: Vec < String > = self . inner . latents . iter ( ) . cloned ( ) . collect ( ) ;
518+ v. sort ( ) ;
519+ v. iter ( ) . map ( |s| s. as_str ( ) ) . collect :: < Strings > ( )
520+ }
521+
522+ /// Directed edges as a list of 2-element character vectors
523+ /// @export
524+ fn directed_edges ( & self ) -> List {
525+ let mut vec = self . inner . directed_edges . iter ( ) . cloned ( ) . collect :: < Vec < _ > > ( ) ;
526+ vec. sort ( ) ;
527+ let mut out = List :: new ( vec. len ( ) ) ;
528+ for ( i, ( u, v) ) in vec. into_iter ( ) . enumerate ( ) {
529+ let pair = vec ! [ u. as_str( ) , v. as_str( ) ] . iter ( ) . map ( |s| * s) . collect :: < Strings > ( ) ;
530+ out. set_elt ( i, Into :: < Robj > :: into ( pair) ) . unwrap ( ) ;
531+ }
532+ out
533+ }
534+
535+ /// Undirected edges reported as stored (u, v) for each undirected pair (original insertion)
536+ /// @export
537+ fn undirected_edges ( & self ) -> List {
538+ let mut vec = self . inner . undirected_edges . iter ( ) . cloned ( ) . collect :: < Vec < _ > > ( ) ;
539+ vec. sort ( ) ;
540+ let mut out = List :: new ( vec. len ( ) ) ;
541+ for ( i, ( u, v) ) in vec. into_iter ( ) . enumerate ( ) {
542+ let pair = vec ! [ u. as_str( ) , v. as_str( ) ] . iter ( ) . map ( |s| * s) . collect :: < Strings > ( ) ;
543+ out. set_elt ( i, Into :: < Robj > :: into ( pair) ) . unwrap ( ) ;
544+ }
545+ out
546+ }
547+
548+ /// All neighbors (directed or undirected) as character vector
549+ /// @export
550+ fn all_neighbors ( & self , node : String ) -> extendr_api:: Result < Strings > {
551+ let s = self . inner . all_neighbors ( & node) . map_err ( |e| Error :: Other ( e) ) ?;
552+ let mut v: Vec < String > = s. into_iter ( ) . collect ( ) ;
553+ v. sort ( ) ;
554+ Ok ( v. iter ( ) . map ( |x| x. as_str ( ) ) . collect :: < Strings > ( ) )
555+ }
556+
557+ /// Directed children
558+ /// @export
559+ fn directed_children ( & self , node : String ) -> extendr_api:: Result < Strings > {
560+ let s = self . inner . directed_children ( & node) . map_err ( |e| Error :: Other ( e) ) ?;
561+ let mut v: Vec < String > = s. into_iter ( ) . collect ( ) ;
562+ v. sort ( ) ;
563+ Ok ( v. iter ( ) . map ( |x| x. as_str ( ) ) . collect :: < Strings > ( ) )
564+ }
565+
566+ /// Directed parents
567+ /// @export
568+ fn directed_parents ( & self , node : String ) -> extendr_api:: Result < Strings > {
569+ let s = self . inner . directed_parents ( & node) . map_err ( |e| Error :: Other ( e) ) ?;
570+ let mut v: Vec < String > = s. into_iter ( ) . collect ( ) ;
571+ v. sort ( ) ;
572+ Ok ( v. iter ( ) . map ( |x| x. as_str ( ) ) . collect :: < Strings > ( ) )
573+ }
574+
575+ /// has_directed_edge
576+ /// @export
577+ fn has_directed_edge ( & self , u : String , v : String ) -> bool {
578+ self . inner . has_directed_edge ( & u, & v)
579+ }
580+
581+ /// has_undirected_edge
582+ /// @export
583+ fn has_undirected_edge ( & self , u : String , v : String ) -> bool {
584+ self . inner . has_undirected_edge ( & u, & v)
585+ }
586+
587+ /// undirected_neighbors
588+ /// @export
589+ fn undirected_neighbors ( & self , node : String ) -> extendr_api:: Result < Strings > {
590+ let s = self . inner . undirected_neighbors ( & node) . map_err ( |e| Error :: Other ( e) ) ?;
591+ let mut v: Vec < String > = s. into_iter ( ) . collect ( ) ;
592+ v. sort ( ) ;
593+ Ok ( v. iter ( ) . map ( |x| x. as_str ( ) ) . collect :: < Strings > ( ) )
594+ }
595+
596+ /// is_adjacent
597+ /// @export
598+ fn is_adjacent ( & self , u : String , v : String ) -> bool {
599+ self . inner . is_adjacent ( & u, & v)
600+ }
601+
602+ /// copy
603+ /// @export
604+ fn copy ( & self ) -> PDAG {
605+ PDAG { inner : self . inner . copy ( ) }
606+ }
607+
608+ /// orient_undirected_edge (returns NULL if inplace = TRUE, otherwise returns new PDAG)
609+ /// @param u
610+ /// @param v
611+ /// @param inplace default TRUE
612+ /// @export
613+ fn orient_undirected_edge ( & mut self , u : String , v : String , inplace : Option < bool > ) -> extendr_api:: Result < Nullable < PDAG > > {
614+ let in_place = inplace. unwrap_or ( true ) ;
615+ match self . inner . orient_undirected_edge ( & u, & v, in_place) {
616+ Ok ( None ) => Ok ( Nullable :: Null ) ,
617+ Ok ( Some ( pdag) ) => Ok ( Nullable :: NotNull ( PDAG { inner : pdag } ) ) ,
618+ Err ( e) => Err ( Error :: Other ( e) ) ,
619+ }
620+ }
621+
622+ /// apply_meeks_rules (apply_r4 bool, inplace bool)
623+ /// @export
624+ fn apply_meeks_rules ( & mut self , apply_r4 : Option < bool > , inplace : Option < bool > ) -> extendr_api:: Result < Nullable < PDAG > > {
625+ let apply_r4 = apply_r4. unwrap_or ( true ) ;
626+ let inplace = inplace. unwrap_or ( false ) ;
627+ match self . inner . apply_meeks_rules ( apply_r4, inplace) {
628+ Ok ( None ) => Ok ( Nullable :: Null ) ,
629+ Ok ( Some ( pdag) ) => Ok ( Nullable :: NotNull ( PDAG { inner : pdag } ) ) ,
630+ Err ( e) => Err ( Error :: Other ( e) ) ,
631+ }
632+ }
633+
634+ /// to_dag -> RDAG
635+ /// @export
636+ fn to_dag ( & self ) -> extendr_api:: Result < RDAG > {
637+ let dag = self . inner . to_dag ( ) . map_err ( |e| Error :: Other ( e) ) ?;
638+ Ok ( RDAG { inner : dag } )
639+ }
640+ }
641+
419642
420643extendr_module ! {
421644 mod causalgraphs;
422645 impl RDAG ;
423646 impl RIndependenceAssertion ;
424647 impl RIndependencies ;
648+ impl PDAG ;
425649}
0 commit comments