Skip to content

Commit a9d08e8

Browse files
feat: pdag bindings [wasm + r] (#26)
* initial pdag impl * * orient edges * meeks rule * pdag to dag * deterministic sorting * initial commit * pdag to dag more tests * add latent method to py DAG * fix direction typo * fix meeks rule 3 * minor fixes * major dfs fix * refactor * add PDAG wasm bindings * Pdag R bindings --------- Co-authored-by: Ankur Ankan <[email protected]>
1 parent e1432db commit a9d08e8

File tree

7 files changed

+618
-0
lines changed

7 files changed

+618
-0
lines changed

r_bindings/causalgraphs/NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method("$",PDAG)
34
S3method("$",RDAG)
45
S3method("$",RIndependenceAssertion)
56
S3method("$",RIndependencies)
7+
S3method("[[",PDAG)
68
S3method("[[",RDAG)
79
S3method("[[",RIndependenceAssertion)
810
S3method("[[",RIndependencies)

r_bindings/causalgraphs/R/extendr-wrappers.R

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,5 +106,59 @@ RIndependencies$is_equivalent <- function(other) .Call(wrap__RIndependencies__is
106106
#' @export
107107
`[[.RIndependencies` <- `$.RIndependencies`
108108

109+
PDAG <- new.env(parent = emptyenv())
110+
111+
PDAG$new <- function() .Call(wrap__PDAG__new)
112+
113+
PDAG$add_node <- function(node, latent) .Call(wrap__PDAG__add_node, self, node, latent)
114+
115+
PDAG$add_nodes_from <- function(nodes, latent) .Call(wrap__PDAG__add_nodes_from, self, nodes, latent)
116+
117+
PDAG$add_edge <- function(u, v, weight, directed) .Call(wrap__PDAG__add_edge, self, u, v, weight, directed)
118+
119+
PDAG$add_edges_from <- function(ebunch, weights, directed) .Call(wrap__PDAG__add_edges_from, self, ebunch, weights, directed)
120+
121+
PDAG$edges <- function() .Call(wrap__PDAG__edges, self)
122+
123+
PDAG$nodes <- function() .Call(wrap__PDAG__nodes, self)
124+
125+
PDAG$node_count <- function() .Call(wrap__PDAG__node_count, self)
126+
127+
PDAG$edge_count <- function() .Call(wrap__PDAG__edge_count, self)
128+
129+
PDAG$latents <- function() .Call(wrap__PDAG__latents, self)
130+
131+
PDAG$directed_edges <- function() .Call(wrap__PDAG__directed_edges, self)
132+
133+
PDAG$undirected_edges <- function() .Call(wrap__PDAG__undirected_edges, self)
134+
135+
PDAG$all_neighbors <- function(node) .Call(wrap__PDAG__all_neighbors, self, node)
136+
137+
PDAG$directed_children <- function(node) .Call(wrap__PDAG__directed_children, self, node)
138+
139+
PDAG$directed_parents <- function(node) .Call(wrap__PDAG__directed_parents, self, node)
140+
141+
PDAG$has_directed_edge <- function(u, v) .Call(wrap__PDAG__has_directed_edge, self, u, v)
142+
143+
PDAG$has_undirected_edge <- function(u, v) .Call(wrap__PDAG__has_undirected_edge, self, u, v)
144+
145+
PDAG$undirected_neighbors <- function(node) .Call(wrap__PDAG__undirected_neighbors, self, node)
146+
147+
PDAG$is_adjacent <- function(u, v) .Call(wrap__PDAG__is_adjacent, self, u, v)
148+
149+
PDAG$copy <- function() .Call(wrap__PDAG__copy, self)
150+
151+
PDAG$orient_undirected_edge <- function(u, v, inplace) .Call(wrap__PDAG__orient_undirected_edge, self, u, v, inplace)
152+
153+
PDAG$apply_meeks_rules <- function(apply_r4, inplace) .Call(wrap__PDAG__apply_meeks_rules, self, apply_r4, inplace)
154+
155+
PDAG$to_dag <- function() .Call(wrap__PDAG__to_dag, self)
156+
157+
#' @export
158+
`$.PDAG` <- function (self, name) { func <- PDAG[[name]]; environment(func) <- environment(); func }
159+
160+
#' @export
161+
`[[.PDAG` <- `$.PDAG`
162+
109163

110164
# nolint end

r_bindings/causalgraphs/src/rust/src/lib.rs

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ fn on_load() {
1212
}
1313

1414

15+
use rust_core::RustPDAG;
1516

1617
#[extendr]
1718
#[derive(Debug, Clone)]
@@ -416,10 +417,233 @@ impl RIndependencies {
416417
}
417418
}
418419

420+
#[extendr]
421+
#[derive(Debug, Clone)]
422+
pub struct PDAG {
423+
inner: RustPDAG,
424+
}
425+
426+
427+
#[extendr]
428+
impl PDAG {
429+
/// Create a new PDAG
430+
/// @export
431+
fn new() -> Self {
432+
PDAG { inner: RustPDAG::new() }
433+
}
434+
435+
/// Add a single node
436+
/// @param node Node name
437+
/// @param latent Whether latent (default FALSE)
438+
/// @export
439+
fn add_node(&mut self, node: String, latent: Option<bool>) -> extendr_api::Result<()> {
440+
self.inner.add_node(node, latent.unwrap_or(false))
441+
.map_err(|e| Error::Other(e.to_string()))
442+
}
443+
444+
/// Add nodes from vector with optional latent mask (NULL means all false)
445+
/// @param nodes character vector
446+
/// @param latent NULL or logical vector
447+
/// @export
448+
fn add_nodes_from(&mut self, nodes: Strings, latent: Nullable<Logicals>) -> extendr_api::Result<()> {
449+
let node_vec: Vec<String> = nodes.iter().map(|s| s.to_string()).collect();
450+
let latent_opt: Option<Vec<bool>> = latent.into_option().map(|v| v.iter().map(|x| x.is_true()).collect());
451+
self.inner.add_nodes_from(node_vec, latent_opt).map_err(|e| Error::Other(e.to_string()))
452+
}
453+
454+
/// Add single edge (directed or undirected)
455+
/// @param u source
456+
/// @param v target
457+
/// @param weight optional numeric (NULL)
458+
/// @param directed bool (TRUE: directed, FALSE: undirected)
459+
/// @export
460+
fn add_edge(&mut self, u: String, v: String, weight: Nullable<f64>, directed: Option<bool>) -> extendr_api::Result<()> {
461+
let w = weight.into_option();
462+
let d = directed.unwrap_or(true);
463+
self.inner.add_edge(u, v, w, d).map_err(|e| Error::Other(e.to_string()))
464+
}
465+
466+
/// Add multiple edges from an R list of pairs: list(c("A","B"), c("C","D"))
467+
/// @param ebunch list of character vectors length 2
468+
/// @param weights NULL or numeric vector
469+
/// @param directed bool
470+
/// @export
471+
fn add_edges_from(&mut self, ebunch: List, weights: Nullable<Doubles>, directed: Option<bool>) -> extendr_api::Result<()> {
472+
// convert ebunch (List) -> Vec<(String,String)>
473+
let mut edges: Vec<(String,String)> = Vec::with_capacity(ebunch.len());
474+
for (i, item) in ebunch.values().enumerate() {
475+
// Each item must be a character vector of length 2
476+
let pair: Strings = item.try_into().map_err(|_| Error::Other(format!("ebunch[{}] must be a character vector of length 2", i)))?;
477+
if pair.len() != 2 {
478+
return Err(Error::Other(format!("ebunch[{}] must have exactly 2 elements", i)));
479+
}
480+
edges.push((pair[0].to_string(), pair[1].to_string()));
481+
}
482+
let weight_opt: Option<Vec<f64>> = weights.into_option().map(|v| v.iter().map(|d| d.inner()).collect());
483+
let directed = directed.unwrap_or(true);
484+
self.inner.add_edges_from(Some(edges), weight_opt, directed).map_err(|e| Error::Other(e.to_string()))
485+
}
486+
487+
/// Return all edges. For PDAG this includes both directed and undirected (both directions placed into graph).
488+
/// Return as list(from = ..., to = ...) same as RDAG$edges()
489+
/// @export
490+
fn edges(&self) -> List {
491+
let edges = self.inner.edges();
492+
let (from, to): (Vec<_>, Vec<_>) = edges.into_iter().unzip();
493+
list!(from = from, to = to)
494+
}
495+
496+
/// Return nodes
497+
/// @export
498+
fn nodes(&self) -> Strings {
499+
self.inner.nodes().iter().map(|s| s.as_str()).collect::<Strings>()
500+
}
501+
502+
/// Number of nodes
503+
/// @export
504+
fn node_count(&self) -> i32 {
505+
self.inner.node_map.len() as i32
506+
}
507+
508+
/// Number of edges (count unique graph edges)
509+
/// @export
510+
fn edge_count(&self) -> i32 {
511+
self.inner.edges().len() as i32
512+
}
513+
514+
/// Latent nodes
515+
/// @export
516+
fn latents(&self) -> Strings {
517+
let mut v: Vec<String> = self.inner.latents.iter().cloned().collect();
518+
v.sort();
519+
v.iter().map(|s| s.as_str()).collect::<Strings>()
520+
}
521+
522+
/// Directed edges as a list of 2-element character vectors
523+
/// @export
524+
fn directed_edges(&self) -> List {
525+
let mut vec = self.inner.directed_edges.iter().cloned().collect::<Vec<_>>();
526+
vec.sort();
527+
let mut out = List::new(vec.len());
528+
for (i, (u, v)) in vec.into_iter().enumerate() {
529+
let pair = vec![u.as_str(), v.as_str()].iter().map(|s| *s).collect::<Strings>();
530+
out.set_elt(i, Into::<Robj>::into(pair)).unwrap();
531+
}
532+
out
533+
}
534+
535+
/// Undirected edges reported as stored (u, v) for each undirected pair (original insertion)
536+
/// @export
537+
fn undirected_edges(&self) -> List {
538+
let mut vec = self.inner.undirected_edges.iter().cloned().collect::<Vec<_>>();
539+
vec.sort();
540+
let mut out = List::new(vec.len());
541+
for (i, (u, v)) in vec.into_iter().enumerate() {
542+
let pair = vec![u.as_str(), v.as_str()].iter().map(|s| *s).collect::<Strings>();
543+
out.set_elt(i, Into::<Robj>::into(pair)).unwrap();
544+
}
545+
out
546+
}
547+
548+
/// All neighbors (directed or undirected) as character vector
549+
/// @export
550+
fn all_neighbors(&self, node: String) -> extendr_api::Result<Strings> {
551+
let s = self.inner.all_neighbors(&node).map_err(|e| Error::Other(e))?;
552+
let mut v: Vec<String> = s.into_iter().collect();
553+
v.sort();
554+
Ok(v.iter().map(|x| x.as_str()).collect::<Strings>())
555+
}
556+
557+
/// Directed children
558+
/// @export
559+
fn directed_children(&self, node: String) -> extendr_api::Result<Strings> {
560+
let s = self.inner.directed_children(&node).map_err(|e| Error::Other(e))?;
561+
let mut v: Vec<String> = s.into_iter().collect();
562+
v.sort();
563+
Ok(v.iter().map(|x| x.as_str()).collect::<Strings>())
564+
}
565+
566+
/// Directed parents
567+
/// @export
568+
fn directed_parents(&self, node: String) -> extendr_api::Result<Strings> {
569+
let s = self.inner.directed_parents(&node).map_err(|e| Error::Other(e))?;
570+
let mut v: Vec<String> = s.into_iter().collect();
571+
v.sort();
572+
Ok(v.iter().map(|x| x.as_str()).collect::<Strings>())
573+
}
574+
575+
/// has_directed_edge
576+
/// @export
577+
fn has_directed_edge(&self, u: String, v: String) -> bool {
578+
self.inner.has_directed_edge(&u, &v)
579+
}
580+
581+
/// has_undirected_edge
582+
/// @export
583+
fn has_undirected_edge(&self, u: String, v: String) -> bool {
584+
self.inner.has_undirected_edge(&u, &v)
585+
}
586+
587+
/// undirected_neighbors
588+
/// @export
589+
fn undirected_neighbors(&self, node: String) -> extendr_api::Result<Strings> {
590+
let s = self.inner.undirected_neighbors(&node).map_err(|e| Error::Other(e))?;
591+
let mut v: Vec<String> = s.into_iter().collect();
592+
v.sort();
593+
Ok(v.iter().map(|x| x.as_str()).collect::<Strings>())
594+
}
595+
596+
/// is_adjacent
597+
/// @export
598+
fn is_adjacent(&self, u: String, v: String) -> bool {
599+
self.inner.is_adjacent(&u, &v)
600+
}
601+
602+
/// copy
603+
/// @export
604+
fn copy(&self) -> PDAG {
605+
PDAG { inner: self.inner.copy() }
606+
}
607+
608+
/// orient_undirected_edge (returns NULL if inplace = TRUE, otherwise returns new PDAG)
609+
/// @param u
610+
/// @param v
611+
/// @param inplace default TRUE
612+
/// @export
613+
fn orient_undirected_edge(&mut self, u: String, v: String, inplace: Option<bool>) -> extendr_api::Result<Nullable<PDAG>> {
614+
let in_place = inplace.unwrap_or(true);
615+
match self.inner.orient_undirected_edge(&u, &v, in_place) {
616+
Ok(None) => Ok(Nullable::Null),
617+
Ok(Some(pdag)) => Ok(Nullable::NotNull(PDAG { inner: pdag })),
618+
Err(e) => Err(Error::Other(e)),
619+
}
620+
}
621+
622+
/// apply_meeks_rules (apply_r4 bool, inplace bool)
623+
/// @export
624+
fn apply_meeks_rules(&mut self, apply_r4: Option<bool>, inplace: Option<bool>) -> extendr_api::Result<Nullable<PDAG>> {
625+
let apply_r4 = apply_r4.unwrap_or(true);
626+
let inplace = inplace.unwrap_or(false);
627+
match self.inner.apply_meeks_rules(apply_r4, inplace) {
628+
Ok(None) => Ok(Nullable::Null),
629+
Ok(Some(pdag)) => Ok(Nullable::NotNull(PDAG { inner: pdag })),
630+
Err(e) => Err(Error::Other(e)),
631+
}
632+
}
633+
634+
/// to_dag -> RDAG
635+
/// @export
636+
fn to_dag(&self) -> extendr_api::Result<RDAG> {
637+
let dag = self.inner.to_dag().map_err(|e| Error::Other(e))?;
638+
Ok(RDAG { inner: dag })
639+
}
640+
}
641+
419642

420643
extendr_module! {
421644
mod causalgraphs;
422645
impl RDAG;
423646
impl RIndependenceAssertion;
424647
impl RIndependencies;
648+
impl PDAG;
425649
}
File renamed without changes.

0 commit comments

Comments
 (0)