Skip to content

Commit 3c6710f

Browse files
authored
Merge pull request #110 from mschauer/dsep
Extend `dsep` for d-separation between vertex sets
2 parents 27e0161 + d6ac606 commit 3c6710f

File tree

2 files changed

+41
-11
lines changed

2 files changed

+41
-11
lines changed

src/dsep.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,19 @@ Check whether `u` and `v` are d-separated given set `s`.
88
Algorithm: unrolled https://arxiv.org/abs/1304.1505
99
1010
"""
11-
function dsep(g::AbstractGraph, u::Integer, v::Integer, S; verbose = false)
11+
function dsep(g::AbstractGraph, U, V, S; verbose = false)
1212
T = eltype(g)
1313
in_seen = falses(nv(g)) # nodes reached earlier backwards
1414
out_seen = falses(nv(g)) # nodes reached earlier forwards
1515
descendant = falses(nv(g)) # descendant in s
16+
isv = falses(nv(g))
1617
blocked = falses(nv(g))
1718

1819
for ve in S
1920
in_seen[ve] = true
2021
blocked[ve] = true
2122
end
2223

23-
(in_seen[u] || in_seen[v]) && throw(ArgumentError("S should not contain u or v"))
24-
25-
u == v && throw(ArgumentError("u == v"))
26-
2724
next = Vector{T}()
2825

2926
# mark vertices with descendants in S
@@ -47,9 +44,25 @@ function dsep(g::AbstractGraph, u::Integer, v::Integer, S; verbose = false)
4744
in_next = Vector{T}()
4845
out_next = Vector{T}()
4946

50-
push!(in_next, u) # treat u as vertex reached backwards
51-
in_seen[u] = true
47+
for u in U
48+
push!(in_next, u) # treat u as vertex reached backwards
49+
in_seen[u] && throw(ArgumentError("U and S not disjoint."))
50+
in_seen[u] = true
51+
end
52+
if V isa Integer
53+
in_seen[V] && throw(ArgumentError("U, V and S not disjoint."))
54+
return dsep_inner!(g, in_next, out_next, descendant, ==(V), blocked, out_seen, in_seen; verbose)
55+
else
56+
isv = falses(nv(g))
57+
for v in V
58+
in_seen[v] && throw(ArgumentError("U, V and S not disjoint."))
59+
isv[v] = true
60+
end
61+
return dsep_inner!(g, in_next, out_next, descendant, w->isv[w], blocked, out_seen, in_seen; verbose)
62+
end
63+
end
5264

65+
function dsep_inner!(g, in_next, out_next, descendant, found, blocked, out_seen, in_seen; verbose=false)
5366
while true
5467
sin = isempty(in_next)
5568
sout = isempty(out_next)
@@ -60,15 +73,15 @@ function dsep(g::AbstractGraph, u::Integer, v::Integer, S; verbose = false)
6073
for w in outneighbors(g, src) # possible collider at destination
6174
if !out_seen[w] && (!blocked[w] || descendant[w])
6275
verbose && println("<- $src -> $w")
63-
w == v && return false
76+
found(w) && return false
6477
push!(out_next, w)
6578
out_seen[w] = true
6679
end
6780
end
6881
for w in inneighbors(g, src)
6982
if !in_seen[w]
7083
verbose && println("<- $src <- $w")
71-
w == v && return false
84+
found(w) && return false
7285
push!(in_next, w)
7386
in_seen[w] = true
7487
end
@@ -79,15 +92,15 @@ function dsep(g::AbstractGraph, u::Integer, v::Integer, S; verbose = false)
7992
for w in outneighbors(g, src) # possible collider at destination
8093
if !out_seen[w] && !blocked[src] && (!blocked[w] || descendant[w])
8194
verbose && println("-> $src -> $w")
82-
w == v && return false
95+
found(w) && return false
8396
push!(out_next, w)
8497
out_seen[w] = true
8598
end
8699
end
87100
for w in inneighbors(g, src) # collider at source
88101
if !in_seen[w] && descendant[src] # shielded collider
89102
verbose && println("-> $src <- $w")
90-
w == v && return false
103+
found(w) && return false
91104
push!(out_next, w)
92105
in_seen[w] = true
93106
end

test/gensearch.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ using Test
77
# however, the algorithms should be deterministic, so we think the
88
# tests are okay for now
99

10+
11+
1012
g1 = SimpleDiGraph(Edge.([(1, 2), (2, 3), (3, 4), (5, 1), (6, 5), (6, 4), (1, 7)]))
1113
X = Set(1)
1214
Xint = 1
@@ -77,3 +79,18 @@ Y = Set(8)
7779
@test Set(list_backdoor_adjustment(g2, Set([6]), Set([8]), Set(Int[]), setdiff(Set(1:8), [1,2]))) == Set([Set([3,4]), Set([4,5]), Set([3,4,5])])
7880
@test Set(list_frontdoor_adjustment(g2, X, Y)) == Set([Set(7)])
7981
end
82+
83+
using Test, CausalInference, Combinatorics
84+
function test_dsep(g)
85+
n = nv(g)
86+
for (_, v, w, z) in partitions(1:n, 4)
87+
@test dsep(g, v, w, z) == alt_test_dsep(g, v, w, z)
88+
end
89+
end
90+
@testset "dsep vs alt_test_dsep" begin
91+
test_dsep(g1)
92+
test_dsep(g2)
93+
@test_throws ArgumentError dsep(g1, [1,2], [2,3], [4,5])
94+
@test_throws ArgumentError dsep(g1, [1,2], [3,4], [4,5])
95+
@test_throws ArgumentError dsep(g1, [1,2], [3,4], [5,1])
96+
end

0 commit comments

Comments
 (0)