Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ Tables
SpecialFunctions
NearestNeighbors
Distances
MetaGraphs
MetaGraphs
DataStructures
CategoricalArrays
Lazy
6 changes: 6 additions & 0 deletions docs/src/library.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ kl_mutual_information
kl_cond_mi
kl_perm_mi_test
kl_perm_cond_mi_test
cat_H
cat_MI
cat_CMI
perm_cat_MI_test
perm_cat_CMI_test
```

## FCI algorithm
```@docs
has_marks
Expand Down
1 change: 1 addition & 0 deletions src/CausalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export n_ball
export fcialg, is_collider, is_triangle, is_parent
export is_discriminating_path, has_marks, set_marks!, is_uncovered_circle_path
export is_uncovered_PD_path, @arrow_str
export cat_H, cat_MI, cat_CMI, perm_cat_MI_test, perm_cat_CMI_test

include("klentropy.jl")
include("skeleton.jl")
Expand Down
86 changes: 84 additions & 2 deletions src/klentropy.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using SpecialFunctions, NearestNeighbors, Distances, Distributions, Random

using SpecialFunctions, NearestNeighbors, Distances, Distributions, Random, CategoricalArrays
using DataStructures: counter
using Lazy: @>>
"""
n_ball(n::Number)
Computes the volume of a n-dimensional unit sphere.
Expand Down Expand Up @@ -225,3 +226,84 @@ function kl_perm_cond_mi_test(x, y, z; k=5, B=100, kp=5, bias_correction=true)

return p
end

"""
cat_H(x)

estimate entropy of categorical variable x

For further information, see:

"Entropy Estimates from Insufficient Samplings"
Peter Grassberger
https://arxiv.org/abs/physics/0307138v2
"""
function cat_H(x)
N = length(x)
counts = values(counter(x))
return log(N) - 1/N * sum(map(n->n*(digamma(n) + (-1)^n * 1/(n*(n+1))), counts))
end

"""
cat_MI(x, y)

estimate mutual information of categorical variables x and y
"""
function cat_MI(x, y)
return cat_H(x) + cat_H(y) - cat_H(zip(x,y))
end

"""
cat_MI(x, y, z)

estimate conditional mutual information of categorical variables x and y given z
"""
function cat_CMI(x, y, z)
return cat_H(zip(x,z)) + cat_H(zip(y,z)) - cat_H(zip(x,y,z)) - cat_H(z)
end

"""
perm_cat_MI_test(x, y; B=100)

perform permutation-based independence test of x and y
"""
function perm_cat_MI_test(x, y; B=100)
samples = Float64[]
MI = cat_MI(x,y)

for i in 1:B
push!(samples, cat_MI(shuffle(x), y))
end

p = length(filter(d->MI<d, samples))/B
return p
end

"""
perm_cat_CMI_test(x, y, z; B=100)

perform permutation-based conditional independence test of x and y given z
"""
function perm_cat_CMI_test(x, y, z; B=100)
samples = Float64[]
CMI = cat_CMI(x,y,z)

xz = collect(zip(x,z))

for i in 1:B
x_shuffle_dict = Dict()
for level in unique(z)
x_shuffle_dict[level] = @>> xz filter(d->d[2]==level) map(d->d[1]) shuffle
end

x_shuffle = []
for zs in z
push!(x_shuffle, pop!(x_shuffle_dict[zs]))
end

push!(samples, cat_CMI(x_shuffle, y, z))
end

p = length(filter(d->CMI<d, samples))/B
return p
end
5 changes: 3 additions & 2 deletions src/pc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,12 @@ conditional independeces using a conditional mutual information permutation test
"""
function pcalg(t, p::Float64, test::typeof(cmitest); kwargs...)
@assert Tables.istable(t)
@assert all(t->t==Float64, Tables.schema(t).types)
# @assert all(t->t==Float64, Tables.schema(t).types)


c = Tables.columns(t)
sch = Tables.schema(t)
n = length(sch.names)

return pcalg(n, cmitest, c, p; kwargs...)
return pcalg(n, cmitest, c, sch, p; kwargs...)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that defined?

end
35 changes: 24 additions & 11 deletions src/skeleton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,32 @@ at p-value crit.
keyword arguments:
kwargs...: keyword arguments passed to independence tests
"""
@inline function cmitest(i, j, s, data, crit; kwargs...)
x=collect(transpose(convert(Array, data[i])))
y=collect(transpose(convert(Array, data[j])))
@inline function cmitest(i, j, s, data, schema, crit; kwargs...)
if all(t -> t==Float64, schema.types)
x=collect(transpose(convert(Array, data[i])))
y=collect(transpose(convert(Array, data[j])))

if length(s)==0
res = kl_perm_mi_test(x, y; kwargs...)
else
z = reduce(vcat, map(c->collect(transpose(convert(Array, data[c]))), s))
res = kl_perm_cond_mi_test(x, y, z; kwargs...)
end
if length(s)==0
res = kl_perm_mi_test(x, y; kwargs...)
else
z = reduce(vcat, map(c->collect(transpose(convert(Array, data[c]))), s))
res = kl_perm_cond_mi_test(x, y, z; kwargs...)
end

#@debug "CMI test for $(i)-$(j) given $(s): $(res) compared to $(crit)"
return res>crit
#@debug "CMI test for $(i)-$(j) given $(s): $(res) compared to $(crit)"
return res>crit
elseif all(t -> t<:CategoricalValue, schema.types)
x = data[i]
y = data[j]

if length(s)==0
res = perm_cat_MI_test(x, y; kwargs...)
else
z = zip(map(c->data[c], s)...)
res = perm_cat_CMI_test(x, y, z; kwargs...)
end
return res>crit
end
end

truetest(i, j, s) = true
Expand Down
15 changes: 15 additions & 0 deletions test/klentropy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,18 @@ end
collect(transpose(x)))>p

end

@testset "Categorical Variables" begin
Random.seed!(123)
N = 1000
p = 0.05

x = rand(1:5, N)
@test abs((cat_H(x) - log(5))/log(5)) < 0.01

x = rand(1:4, N)
v = map(d->floor(d/2), x) .+ rand(1:2, N)
w = map(d->floor(d/2), x) .+ rand(1:2, N)
@test perm_cat_MI_test(v,w) < p
@test perm_cat_CMI_test(v,w,x) > p
end
18 changes: 17 additions & 1 deletion test/pc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Test
using Random
using LightGraphs
using Distributions
using CategoricalArrays
using CausalInference: disjoint_sorted

@test disjoint_sorted([],[1,2])
Expand Down Expand Up @@ -58,7 +59,22 @@ println("Running Gaussian tests")
println("Running CMI tests")
@time cmi_g = pcalg(df, 0.1, cmitest)

@testset "pcalg_edgde_test" begin
@testset "continuous pcalg_edge_test" begin
@test collect(LightGraphs.edges(cmi_g)) == collect(LightGraphs.edges(dg))
@test collect(LightGraphs.edges(gaussci_g)) == collect(LightGraphs.edges(dg))
end

x = rand(1:5, N)
v = x + rand(1:2, N)
w = x + rand(1:2, N)
z = ((v + w)/2) + rand(1:2, N)
s = z + rand(1:2, N)

df = (x=CategoricalArray(x), v=CategoricalArray(v),
w=CategoricalArray(w), z=CategoricalArray(z),
s=CategoricalArray(s))

@time cmi_cat_g = pcalg(df, 0.1, cmitest)
@testset "categorical pcalg_edge_test" begin
@test collect(LightGraphs.edges(cmi_cat_g)) == collect(LightGraphs.edges(dg))
end
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ using Test

include(joinpath("..", "docs", "make.jl"))

include("fci.jl")
include("klentropy.jl")
include("skeleton.jl")
include("dsep.jl")
include("pc.jl")
include("dsep.jl")
include("skeleton.jl")
include("cpdag.jl")
include("fci.jl")
1 change: 0 additions & 1 deletion test/skeleton.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

using CausalInference
using LightGraphs
using Test
Expand Down