diff --git a/REQUIRE b/REQUIRE index 88756ff1..c71e2796 100644 --- a/REQUIRE +++ b/REQUIRE @@ -6,4 +6,7 @@ Tables SpecialFunctions NearestNeighbors Distances -MetaGraphs \ No newline at end of file +MetaGraphs +DataStructures +CategoricalArrays +Lazy \ No newline at end of file diff --git a/docs/src/library.md b/docs/src/library.md index f28be266..109ead47 100644 --- a/docs/src/library.md +++ b/docs/src/library.md @@ -41,7 +41,13 @@ kl_mutual_information kl_cond_mi kl_perm_mi_test kl_perm_cond_mi_test +cat_H +cat_MI +cat_CMI +perm_cat_MI_test +perm_cat_CMI_test ``` + ## FCI algorithm ```@docs has_marks diff --git a/src/CausalInference.jl b/src/CausalInference.jl index 7dd3a27d..e03f1df8 100644 --- a/src/CausalInference.jl +++ b/src/CausalInference.jl @@ -14,6 +14,7 @@ export n_ball export fcialg, is_collider, is_triangle, is_parent export is_discriminating_path, has_marks, set_marks!, is_uncovered_circle_path export is_uncovered_PD_path, @arrow_str +export cat_H, cat_MI, cat_CMI, perm_cat_MI_test, perm_cat_CMI_test export plot_pc_graph, plot_fci_graph export orient_unshielded, orientable_unshielded, apply_pc_rules diff --git a/src/klentropy.jl b/src/klentropy.jl index 2f79c4c9..6d24fb0f 100644 --- a/src/klentropy.jl +++ b/src/klentropy.jl @@ -1,5 +1,6 @@ -using SpecialFunctions, NearestNeighbors, Distances, Distributions, Random - +using SpecialFunctions, NearestNeighbors, Distances, Distributions, Random, CategoricalArrays +using DataStructures: counter +using Lazy: @>> """ n_ball(n::Number) Computes the volume of a n-dimensional unit sphere. @@ -225,3 +226,84 @@ function kl_perm_cond_mi_test(x, y, z; k=5, B=100, kp=5, bias_correction=true) return p end + +""" + cat_H(x) + +estimate entropy of categorical variable x + +For further information, see: + +"Entropy Estimates from Insufficient Samplings" +Peter Grassberger +https://arxiv.org/abs/physics/0307138v2 +""" +function cat_H(x) + N = length(x) + counts = values(counter(x)) + return log(N) - 1/N * sum(map(n->n*(digamma(n) + (-1)^n * 1/(n*(n+1))), counts)) +end + +""" + cat_MI(x, y) + +estimate mutual information of categorical variables x and y +""" +function cat_MI(x, y) + return cat_H(x) + cat_H(y) - cat_H(zip(x,y)) +end + +""" + cat_MI(x, y, z) + +estimate conditional mutual information of categorical variables x and y given z +""" +function cat_CMI(x, y, z) + return cat_H(zip(x,z)) + cat_H(zip(y,z)) - cat_H(zip(x,y,z)) - cat_H(z) +end + +""" + perm_cat_MI_test(x, y; B=100) + +perform permutation-based independence test of x and y +""" +function perm_cat_MI_test(x, y; B=100) + samples = Float64[] + MI = cat_MI(x,y) + + for i in 1:B + push!(samples, cat_MI(shuffle(x), y)) + end + + p = length(filter(d->MI> xz filter(d->d[2]==level) map(d->d[1]) shuffle + end + + x_shuffle = [] + for zs in z + push!(x_shuffle, pop!(x_shuffle_dict[zs])) + end + + push!(samples, cat_CMI(x_shuffle, y, z)) + end + + p = length(filter(d->CMIt==Float64, Tables.schema(t).types) + # @assert all(t->t==Float64, Tables.schema(t).types) + c = Tables.columns(t) sch = Tables.schema(t) n = length(sch.names) - return pcalg(n, cmitest, c, p; kwargs...) + return pcalg(n, cmitest, c, sch, p; kwargs...) end diff --git a/src/skeleton.jl b/src/skeleton.jl index d7aca6da..c58abb71 100644 --- a/src/skeleton.jl +++ b/src/skeleton.jl @@ -116,7 +116,7 @@ end """ - cmitest(i,j,s,data,crit; kwargs...) + cmitest(i, j, s, data, [schema,] crit; kwargs...) Test for conditional independence of variables i and j given variables in s with permutation test using nearest neighbor conditional mutual information estimates @@ -125,6 +125,33 @@ at p-value crit. keyword arguments: kwargs...: keyword arguments passed to independence tests """ +@inline function cmitest(i, j, s, data, schema, crit; kwargs...) + if all(t -> t==Float64, schema.types) + x=collect(transpose(convert(Array, data[i]))) + y=collect(transpose(convert(Array, data[j]))) + + if length(s)==0 + res = kl_perm_mi_test(x, y; kwargs...) + else + z = reduce(vcat, map(c->collect(transpose(convert(Array, data[c]))), s)) + res = kl_perm_cond_mi_test(x, y, z; kwargs...) + end + + #@debug "CMI test for $(i)-$(j) given $(s): $(res) compared to $(crit)" + return res>crit + elseif all(t -> t<:CategoricalValue, schema.types) + x = data[i] + y = data[j] + + if length(s)==0 + res = perm_cat_MI_test(x, y; kwargs...) + else + z = zip(map(c->data[c], s)...) + res = perm_cat_CMI_test(x, y, z; kwargs...) + end + return res>crit + end +end @inline function cmitest(i, j, s, data, crit; kwargs...) columns = Tables.columns(data) diff --git a/test/klentropy.jl b/test/klentropy.jl index 3d19a06d..54ddd697 100644 --- a/test/klentropy.jl +++ b/test/klentropy.jl @@ -53,3 +53,18 @@ end collect(transpose(x)))>p end + +@testset "Categorical Variables" begin + Random.seed!(123) + N = 1000 + p = 0.05 + + x = rand(1:5, N) + @test abs((cat_H(x) - log(5))/log(5)) < 0.01 + + x = rand(1:4, N) + v = map(d->floor(d/2), x) .+ rand(1:2, N) + w = map(d->floor(d/2), x) .+ rand(1:2, N) + @test perm_cat_MI_test(v,w) < p + @test perm_cat_CMI_test(v,w,x) > p +end diff --git a/test/pc.jl b/test/pc.jl index 6ea528d5..cd2ae533 100644 --- a/test/pc.jl +++ b/test/pc.jl @@ -4,6 +4,7 @@ using Test using Random using LightGraphs using Distributions +using CategoricalArrays using CausalInference: disjoint_sorted @test disjoint_sorted([],[1,2]) @@ -58,7 +59,22 @@ println("Running Gaussian tests") println("Running CMI tests") @time cmi_g = pcalg(df, 0.1, cmitest) -@testset "pcalg_edgde_test" begin +@testset "continuous pcalg_edge_test" begin @test collect(LightGraphs.edges(cmi_g)) == collect(LightGraphs.edges(dg)) @test collect(LightGraphs.edges(gaussci_g)) == collect(LightGraphs.edges(dg)) end + +x = rand(1:5, N) +v = x + rand(1:2, N) +w = x + rand(1:2, N) +z = ((v + w)/2) + rand(1:2, N) +s = z + rand(1:2, N) + +df = (x=CategoricalArray(x), v=CategoricalArray(v), + w=CategoricalArray(w), z=CategoricalArray(z), + s=CategoricalArray(s)) + +@time cmi_cat_g = pcalg(df, 0.1, cmitest) +@testset "categorical pcalg_edge_test" begin + @test collect(LightGraphs.edges(cmi_cat_g)) == collect(LightGraphs.edges(dg)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 64053e31..fdf7e3a1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,8 @@ using Test include("skeleton.jl") include("dsep.jl") include("pc.jl") +include("dsep.jl") +include("skeleton.jl") include("cpdag.jl") include("combinations.jl") include("witness.jl") diff --git a/test/skeleton.jl b/test/skeleton.jl index b8d8136c..001f2241 100644 --- a/test/skeleton.jl +++ b/test/skeleton.jl @@ -1,4 +1,3 @@ - using CausalInference using LightGraphs using Test