Skip to content

WIP: independence tests for categorical variables #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ Tables
SpecialFunctions
NearestNeighbors
Distances
MetaGraphs
MetaGraphs
DataStructures
CategoricalArrays
Lazy
6 changes: 6 additions & 0 deletions docs/src/library.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,13 @@ kl_mutual_information
kl_cond_mi
kl_perm_mi_test
kl_perm_cond_mi_test
cat_H
cat_MI
cat_CMI
perm_cat_MI_test
perm_cat_CMI_test
```

## FCI algorithm
```@docs
has_marks
Expand Down
1 change: 1 addition & 0 deletions src/CausalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export n_ball
export fcialg, is_collider, is_triangle, is_parent
export is_discriminating_path, has_marks, set_marks!, is_uncovered_circle_path
export is_uncovered_PD_path, @arrow_str
export cat_H, cat_MI, cat_CMI, perm_cat_MI_test, perm_cat_CMI_test
export plot_pc_graph, plot_fci_graph
export orient_unshielded, orientable_unshielded, apply_pc_rules

Expand Down
86 changes: 84 additions & 2 deletions src/klentropy.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using SpecialFunctions, NearestNeighbors, Distances, Distributions, Random

using SpecialFunctions, NearestNeighbors, Distances, Distributions, Random, CategoricalArrays
using DataStructures: counter
using Lazy: @>>
"""
n_ball(n::Number)
Computes the volume of a n-dimensional unit sphere.
Expand Down Expand Up @@ -225,3 +226,84 @@ function kl_perm_cond_mi_test(x, y, z; k=5, B=100, kp=5, bias_correction=true)

return p
end

"""
cat_H(x)

estimate entropy of categorical variable x

For further information, see:

"Entropy Estimates from Insufficient Samplings"
Peter Grassberger
https://arxiv.org/abs/physics/0307138v2
"""
function cat_H(x)
N = length(x)
counts = values(counter(x))
return log(N) - 1/N * sum(map(n->n*(digamma(n) + (-1)^n * 1/(n*(n+1))), counts))
end

"""
cat_MI(x, y)

estimate mutual information of categorical variables x and y
"""
function cat_MI(x, y)
return cat_H(x) + cat_H(y) - cat_H(zip(x,y))
end

"""
cat_MI(x, y, z)

estimate conditional mutual information of categorical variables x and y given z
"""
function cat_CMI(x, y, z)
return cat_H(zip(x,z)) + cat_H(zip(y,z)) - cat_H(zip(x,y,z)) - cat_H(z)
end

"""
perm_cat_MI_test(x, y; B=100)

perform permutation-based independence test of x and y
"""
function perm_cat_MI_test(x, y; B=100)
samples = Float64[]
MI = cat_MI(x,y)

for i in 1:B
push!(samples, cat_MI(shuffle(x), y))
end

p = length(filter(d->MI<d, samples))/B
return p
end

"""
perm_cat_CMI_test(x, y, z; B=100)

perform permutation-based conditional independence test of x and y given z
"""
function perm_cat_CMI_test(x, y, z; B=100)
samples = Float64[]
CMI = cat_CMI(x,y,z)

xz = collect(zip(x,z))

for i in 1:B
x_shuffle_dict = Dict()
for level in unique(z)
x_shuffle_dict[level] = @>> xz filter(d->d[2]==level) map(d->d[1]) shuffle
end

x_shuffle = []
for zs in z
push!(x_shuffle, pop!(x_shuffle_dict[zs]))
end

push!(samples, cat_CMI(x_shuffle, y, z))
end

p = length(filter(d->CMI<d, samples))/B
return p
end
5 changes: 3 additions & 2 deletions src/pc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,14 @@ conditional independeces using a conditional mutual information permutation test
"""
function pcalg(t, p::Float64, test::typeof(cmitest); kwargs...)
@assert Tables.istable(t)
@assert all(t->t==Float64, Tables.schema(t).types)
# @assert all(t->t==Float64, Tables.schema(t).types)


c = Tables.columns(t)
sch = Tables.schema(t)
n = length(sch.names)

return pcalg(n, cmitest, c, p; kwargs...)
return pcalg(n, cmitest, c, sch, p; kwargs...)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that defined?

end


Expand Down
29 changes: 28 additions & 1 deletion src/skeleton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ end


"""
cmitest(i,j,s,data,crit; kwargs...)
cmitest(i, j, s, data, [schema,] crit; kwargs...)

Test for conditional independence of variables i and j given variables in s with
permutation test using nearest neighbor conditional mutual information estimates
Expand All @@ -125,6 +125,33 @@ at p-value crit.
keyword arguments:
kwargs...: keyword arguments passed to independence tests
"""
@inline function cmitest(i, j, s, data, schema, crit; kwargs...)
if all(t -> t==Float64, schema.types)
x=collect(transpose(convert(Array, data[i])))
y=collect(transpose(convert(Array, data[j])))

if length(s)==0
res = kl_perm_mi_test(x, y; kwargs...)
else
z = reduce(vcat, map(c->collect(transpose(convert(Array, data[c]))), s))
res = kl_perm_cond_mi_test(x, y, z; kwargs...)
end

#@debug "CMI test for $(i)-$(j) given $(s): $(res) compared to $(crit)"
return res>crit
elseif all(t -> t<:CategoricalValue, schema.types)
x = data[i]
y = data[j]

if length(s)==0
res = perm_cat_MI_test(x, y; kwargs...)
else
z = zip(map(c->data[c], s)...)
res = perm_cat_CMI_test(x, y, z; kwargs...)
end
return res>crit
end
end
@inline function cmitest(i, j, s, data, crit; kwargs...)
columns = Tables.columns(data)

Expand Down
15 changes: 15 additions & 0 deletions test/klentropy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,18 @@ end
collect(transpose(x)))>p

end

@testset "Categorical Variables" begin
Random.seed!(123)
N = 1000
p = 0.05

x = rand(1:5, N)
@test abs((cat_H(x) - log(5))/log(5)) < 0.01

x = rand(1:4, N)
v = map(d->floor(d/2), x) .+ rand(1:2, N)
w = map(d->floor(d/2), x) .+ rand(1:2, N)
@test perm_cat_MI_test(v,w) < p
@test perm_cat_CMI_test(v,w,x) > p
end
18 changes: 17 additions & 1 deletion test/pc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Test
using Random
using LightGraphs
using Distributions
using CategoricalArrays
using CausalInference: disjoint_sorted

@test disjoint_sorted([],[1,2])
Expand Down Expand Up @@ -58,7 +59,22 @@ println("Running Gaussian tests")
println("Running CMI tests")
@time cmi_g = pcalg(df, 0.1, cmitest)

@testset "pcalg_edgde_test" begin
@testset "continuous pcalg_edge_test" begin
@test collect(LightGraphs.edges(cmi_g)) == collect(LightGraphs.edges(dg))
@test collect(LightGraphs.edges(gaussci_g)) == collect(LightGraphs.edges(dg))
end

x = rand(1:5, N)
v = x + rand(1:2, N)
w = x + rand(1:2, N)
z = ((v + w)/2) + rand(1:2, N)
s = z + rand(1:2, N)

df = (x=CategoricalArray(x), v=CategoricalArray(v),
w=CategoricalArray(w), z=CategoricalArray(z),
s=CategoricalArray(s))

@time cmi_cat_g = pcalg(df, 0.1, cmitest)
@testset "categorical pcalg_edge_test" begin
@test collect(LightGraphs.edges(cmi_cat_g)) == collect(LightGraphs.edges(dg))
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ using Test
include("skeleton.jl")
include("dsep.jl")
include("pc.jl")
include("dsep.jl")
include("skeleton.jl")
include("cpdag.jl")
include("combinations.jl")
include("witness.jl")
Expand Down
1 change: 0 additions & 1 deletion test/skeleton.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

using CausalInference
using LightGraphs
using Test
Expand Down