diff --git a/docs/src/library.md b/docs/src/library.md index 8e3f97a5..3d489394 100644 --- a/docs/src/library.md +++ b/docs/src/library.md @@ -31,6 +31,8 @@ CausalInference.meek_rule2 CausalInference.meek_rule3 CausalInference.meek_rule4 pdag2dag! +pdag_to_dag_dortasi! +pdag_to_dag_meek! ``` ## PC algorithm diff --git a/src/CausalInference.jl b/src/CausalInference.jl index 521f97a3..1d4b70ea 100644 --- a/src/CausalInference.jl +++ b/src/CausalInference.jl @@ -25,7 +25,7 @@ export plot_pc_graph_recipes, plot_fci_graph_recipes # if GraphRecipes is loaded export plot_pc_graph_tikz, plot_fci_graph_tikz # if TikzGraphs is loaded export orient_unshielded, orientable_unshielded, apply_pc_rules export ges -export pdag2dag! +export pdag2dag!, pdag_to_dag_meek!, pdag_to_dag_dortasi! #include("pinv.jl") include("graphs.jl") diff --git a/src/meek.jl b/src/meek.jl index 31a682a9..955e1b40 100644 --- a/src/meek.jl +++ b/src/meek.jl @@ -108,11 +108,11 @@ function meek_rule4(dg, v, w) end """ - pdag2dag!(g, rule4=false) + pdag_to_dag_meek!(g, rule4=false) Complete PDAG to DAG using meek_rules. """ -function pdag2dag!(g, rule4=false) +function pdag_to_dag_meek!(g, rule4=false) while true # find unoriented edge for e in edges(g) # go through edges (bad to start in the beginning?) @@ -127,4 +127,9 @@ function pdag2dag!(g, rule4=false) @label orient meek_rules!(g; rule4) end + g end +""" +Deprecated alias for `pdag_to_dag_meek!`. +""" +const pdag2dag! = pdag_to_dag_meek! \ No newline at end of file diff --git a/src/pdag.jl b/src/pdag.jl index eeb39f3d..9051e61d 100644 --- a/src/pdag.jl +++ b/src/pdag.jl @@ -147,3 +147,40 @@ Children of x in g are vertices y such that there is a directed edge y <-- x. Returns sorted array. """ children(g, x) = setdiff(outneighbors(g, x), inneighbors(g, x)) + +""" + pdag_to_dag_dortasi!!(g) + +Complete PDAG to DAG using Dor & Tasi (1992). +""" +function pdag_to_dag_dortasi!(g) + removed = falses(nv(g)) # Mark vertices removed from (sub-)graph A. Efficient if degree small? + while !all(removed) + touched = false + for x in vertices(g) + removed[x] && continue + for y in outneighbors(g, x) + removed[y] && continue + has_edge(g, y, x) || @goto skip # not a sink + end + for y in neighbors_undirected(g, x) + removed[y] && continue + for z in inneighbors(g, x) # contains all adjacents by assumption + removed[z] && continue + y==z || isadjacent(g, y, z) || @goto skip + end + end + for y in copy(outneighbors(g, x)) + removed[y] && continue + rem_edge!(g, x, y) + end + touched = true + removed[x] = true + @label skip + end + if !touched + error("PDAG has no consistent extension to a DAG") + end + end + return g +end diff --git a/test/cpdag.jl b/test/cpdag.jl index 021ec394..bdad8dbe 100644 --- a/test/cpdag.jl +++ b/test/cpdag.jl @@ -57,6 +57,16 @@ for stable in (true, false) h1 = pc_oracle(g; stable) h2 = cpdag(g) + + g2 = pdag_to_dag_dortasi!(copy(h2)) + @test !is_cyclic(g2) + @test h2 == cpdag(g2) + + g2 = pdag_to_dag_meek!(copy(h2)) + @test !is_cyclic(g2) + @test h2 == cpdag(g2) + + h1 == h2 || println(vpairs(g)) @test vpairs(h1) ⊆ vpairs(h2) @test vpairs(h2) ⊆ vpairs(h1)