diff --git a/Project.toml b/Project.toml
index 9eb78c89..a2e83497 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ version = "0.15.1"
 
 [deps]
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -26,6 +27,17 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 TabularDisplay = "3eeacb1d-13c2-54cc-9b18-30c86af3cadb"
 ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
 
+[weakdeps]
+GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
+GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"
+
+[extensions]
+GraphMakieExt = "GraphMakie"
+GraphRecipesExt = ["GraphRecipes", "Plots"]
+TikzGraphsExt = "TikzGraphs"
+
 [compat]
 Combinatorics = "1.0"
 DelimitedFiles = "1.6, 1.7, 1.8, 1.9"
@@ -55,11 +67,6 @@ ThreadsX = "0.1"
 TikzGraphs = "1.3, 1.4"
 julia = "1.6, 1.7, 1.8, 1.9, 1.10"
 
-[extensions]
-GraphMakieExt = "GraphMakie"
-GraphRecipesExt = ["GraphRecipes", "Plots"]
-TikzGraphsExt = "TikzGraphs"
-
 [extras]
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
 GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
@@ -71,9 +78,3 @@ TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"
 
 [targets]
 test = ["Test", "StatsBase", "DelimitedFiles"]
-
-[weakdeps]
-GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
-GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
-Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
-TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"
diff --git a/src/CausalInference.jl b/src/CausalInference.jl
index 777a2750..5d1f0fea 100644
--- a/src/CausalInference.jl
+++ b/src/CausalInference.jl
@@ -9,9 +9,11 @@ using Base.Iterators
 using Memoization, LRUCache
 using ThreadsX
 using LinkedLists
+using DataStructures
 
 import Base: ==, show
 
+export multisampler
 export exactscorebased
 export ancestors, descendants, alt_test_dsep, test_covariate_adjustment, alt_test_backdoor, find_dsep, find_min_dsep, find_covariate_adjustment, find_backdoor_adjustment, find_frontdoor_adjustment, find_min_covariate_adjustment, find_min_backdoor_adjustment, find_min_frontdoor_adjustment, list_dseps, list_covariate_adjustment, list_backdoor_adjustment, list_frontdoor_adjustment
 export dsep, skeleton, gausscitest, dseporacle, partialcor
@@ -56,6 +58,7 @@ include("dag_sampler.jl")
 include("misc2.jl")
 include("exact.jl")
 #include("mcs.jl")
+include("multisampler.jl")
 
 # Compatibility with the new "Package Extensions" (https://github.com/JuliaLang/julia/pull/47695)
 const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension)
diff --git a/src/multisampler.jl b/src/multisampler.jl
new file mode 100644
index 00000000..5b93482d
--- /dev/null
+++ b/src/multisampler.jl
@@ -0,0 +1,264 @@
+"""
+    Sample(g, τ, dir, total, scoreval[, alive])
+
+Immutable state of one sampler (particle) of `multisampler`.
+
+# Fields
+- `g`: current graph of the sampler.
+- `τ`: time at which this state was written.
+- `dir`: direction flag; `sampleaction` reads `dir == 1` as "up", `applyflip` negates it.
+- `total`: counter incremented by `applyup` and decremented by `applydown`
+  (presumably the number of edges of `g` -- TODO confirm).
+- `scoreval`: running score, advanced by the `Δscoreval` of every applied move.
+- `alive`: `false` once the sampler has been stopped by `applykill`.
+
+The five-argument constructor creates a live sample (`alive = true`).
+"""
+struct Sample
+    g::DiGraph
+    τ::Float64
+    dir::Int8
+    total::Int32
+    scoreval::Float64
+    alive::Bool
+end
+Sample(g, nextτ, dir, total, scoreval) = Sample(g, nextτ, dir, total, scoreval, true)
+
+"""
+    Action(i, τ, apply!, args)
+
+Scheduled event for sampler `i`, due at time `τ`. `multisampler` executes it as
+`apply!(samplers, i, τ, args...)`; the `apply*` functions below overwrite
+`samplers[i]` and return the new `Sample`.
+"""
+struct Action
+    i::Int
+    τ::Float64
+    apply!::Function
+    args::Tuple
+end
+
+"""
+    expcoldness(τ, k=0.0005)
+
+Exponential coldness schedule `exp(k*τ)`.
+"""
+function expcoldness(τ, k=0.0005)
+    return exp(k*τ)
+end
+
+"""
+    Dexpcoldness(τ, k=0.0005)
+
+Derivative of `expcoldness` with respect to `τ`, i.e. `k*exp(k*τ)`.
+"""
+function Dexpcoldness(τ, k=0.0005)
+    return k*exp(k*τ)
+end
+
+# Build a live `Sample` from explicit values; the first two arguments are ignored
+# (presumably they mirror the `samplers, i` slots of the `apply*` functions).
+function init(_, _, nextτ, g, dir, total, scoreval)
+    return Sample(g, nextτ, dir, total, scoreval)
+end
+
+# "Up" move: advance sampler `i` with `next_CPDAG(g, :up, x, y, T)` and count one up.
+function applyup(samplers, i, nextτ, x, y, T, Δscoreval)
+    prev = samplers[i]
+    g = next_CPDAG(prev.g, :up, x, y, T)
+    return samplers[i] = Sample(g, nextτ, prev.dir, prev.total + 1, prev.scoreval + Δscoreval)
+end
+
+# "Down" move: advance sampler `i` with `next_CPDAG(g, :down, x, y, H)` and count one down.
+function applydown(samplers, i, nextτ, x, y, H, Δscoreval)
+    prev = samplers[i]
+    g = next_CPDAG(prev.g, :down, x, y, H)
+    return samplers[i] = Sample(g, nextτ, prev.dir, prev.total - 1, prev.scoreval + Δscoreval)
+end
+
+# Reverse the direction flag of sampler `i`; graph, counter and score are unchanged.
+function applyflip(samplers, i, nextτ)
+    prev = samplers[i]
+    return samplers[i] = Sample(prev.g, nextτ, -prev.dir, prev.total, prev.scoreval)
+end
+
+# Overwrite sampler `i` with the state of sampler `j`. A state taken from another
+# sampler gets the opposite direction; for `i == j` the direction is kept.
+function applycopy(samplers, i, nextτ, j)
+    src = samplers[j]
+    s = (i == j) ? 1 : -1 # move opposite direction
+    return samplers[i] = Sample(src.g, nextτ, s*src.dir, src.total, src.scoreval)
+end
+
+# Mark sampler `i` as dead; everything else is carried over.
+function applykill(samplers, i, nextτ)
+    prev = samplers[i]
+    return samplers[i] = Sample(prev.g, nextτ, prev.dir, prev.total, prev.scoreval, false)
+end
+
+# Placeholder action of dead samplers. `sampleaction` schedules it at `τ = Inf`, so it is
+# not meant to be executed; reaching it indicates a scheduling bug.
+function applynothing(samplers, i, nextτ)
+    error("applynothing was executed for sampler $i, but dead samplers must never be run")
+end
+
+# for starters without turn move
+
+"""
+    sampleaction(samplers, i, M, balance, prior, score, maxscoreval, σ, ρ, κ,
+                 coldness, Dcoldness, threshold, keep, force)
+
+Draw the next `Action` of sampler `i`.
+
+Four exponential clocks compete and the earliest one fires:
+- `:dir`: move in the current direction of the sampler (rate `ρ*λdir`),
+- `:updown`: move up or down, chosen proportionally to `sup` and `sdown` (rate `σ*λupdown`),
+- `:flip`: reverse the direction (rate `ρ*λflip`),
+- `:term`: with probability `keep` copy the state of a sampler (drawn among the live
+  ones if `keep < 1`, among all `1:M` otherwise), else kill sampler `i` (rate `abs(λterm)`).
+
+A dead sampler gets the placeholder `applynothing` at `τ = Inf`.
+"""
+function sampleaction(samplers, i, M, balance, prior, score, maxscoreval, σ, ρ, κ, coldness, Dcoldness, threshold, keep, force)
+    prevsample = samplers[i]
+    # Park dead samplers at τ = Inf: the queue serves every finite time first.
+    prevsample.alive || return Action(i, Inf, applynothing, ())
+    τ = prevsample.τ
+
+    sup, sdown, Δscorevalup, Δscorevaldown, argsup, argsdown = count_moves_new(prevsample.g, κ, balance, prior, score, coldness(τ), prevsample.total)
+
+    # event rates
+    λdir = prevsample.dir == 1 ? sup : sdown
+    λupdown = sup + sdown
+    λflip = max(prevsample.dir*(-sup + sdown), 0.0)
+    λterm = force*exp(ULogarithmic, 0.0)*Dcoldness(τ) * clamp(maxscoreval - prevsample.scoreval, eps(), threshold) # TODO: prior
+
+    # Waiting times. Keep this order of random draws: seeded runs depend on it.
+    Δτdir = randexp()/(ρ*λdir)
+    Δτupdown = randexp()/(σ*λupdown)
+    Δτflip = randexp()/(ρ*λflip)
+    Δτterm = randexp()/abs(λterm)
+    Δτmin, a = findmin((Δτdir, Δτupdown, Δτflip, Δτterm))
+    A = (:dir, :updown, :flip, :term)[a]
+    @assert Δτmin >= 0
+
+    if A == :dir
+        if prevsample.dir == 1
+            return Action(i, τ + Δτdir, applyup, (argsup..., Δscorevalup))
+        else
+            return Action(i, τ + Δτdir, applydown, (argsdown..., Δscorevaldown))
+        end
+    elseif A == :updown
+        if rand() < sup/λupdown
+            return Action(i, τ + Δτupdown, applyup, (argsup..., Δscorevalup))
+        else
+            return Action(i, τ + Δτupdown, applydown, (argsdown..., Δscorevaldown))
+        end
+    elseif A == :flip
+        return Action(i, τ + Δτflip, applyflip, ())
+    else # :term is the only remaining event
+        if rand() < keep
+            if keep < 1
+                j = rand(findall(s.alive for s in samplers))
+            else
+                j = rand(1:M)
+            end
+            return Action(i, τ + Δτterm, applycopy, (j,))
+        else
+            return Action(i, τ + Δτterm, applykill, ())
+        end
+    end
+end
+
+# remark: chose κ = n-1 as default
+"""
+    multisampler(n, G = (DiGraph(n), 0); M = 10, balance = metropolis_balance,
+                 prior = (_,_) -> 1.0, score = UniformScore(), σ = 0.0, ρ = 1.0,
+                 κ = n - 1, baseline = 0.0, iterations = min(3*n^2, 50000),
+                 schedule = (expcoldness, Dexpcoldness), target = 1e10,
+                 threshold = Inf, keep = 1.0, force = 1.0)
+
+Run `M` samplers over graphs on `n` vertices, all started from `G[1]` with counter
+`G[2]`, and process their events in time order from a priority queue.
+
+# Keywords
+- `M`: number of samplers.
+- `balance`, `prior`, `score`, `κ`: passed on to `count_moves_new`; `κ` is truncated
+  to `n - 1` (with a warning) if it is larger. `score` also scores the start graph.
+- `σ`, `ρ`, `threshold`, `keep`, `force`: rate parameters, see `sampleaction`.
+- `baseline`: currently unused.
+- `iterations`: the total budget of non-flip events is `iterations*M`.
+- `schedule`: pair `(coldness, Dcoldness)` of the coldness as function of time and its
+  derivative.
+- `target`: the run stops as soon as the coldness exceeds this value.
+
+The run also stops when no sampler is alive anymore.
+
+# Returns
+`(bestgraph, bestscore, samples)`: the best scoring graph seen, its score, and the
+final `Sample`s of the samplers that are still alive.
+"""
+function multisampler(n, G = (DiGraph(n), 0); M = 10, balance = metropolis_balance, prior = (_,_) -> 1.0, score=UniformScore(), σ = 0.0, ρ = 1.0, κ = n - 1, baseline = 0.0, iterations = min(3*n^2, 50000), schedule=(expcoldness, Dexpcoldness), target=1e10, threshold=Inf, keep=1.0, force=1.0) #, verbose = false, save = true)
+    if κ >= n
+        κ = n - 1
+        @warn "Truncate κ to $κ"
+    end
+    coldness, Dcoldness = schedule
+
+    # Score the actual start graph instead of the empty graph, so that `scoreval` and
+    # `bestscore` are also right for a non-empty `G[1]`. `pdag2dag!` presumably works
+    # in place (per its `!`), hence the copy.
+    g0, total0 = G
+    initscoreval = score_dag(pdag2dag!(copy(g0)), score)
+    bestgraph = g0
+    bestscore = initscoreval
+
+    # init M samplers
+    samplers = [Sample(g0, 0.0, 1, total0, initscoreval) for _ in 1:M]
+    queue = PriorityQueue{Action, Float64}()
+    for i in eachindex(samplers)
+        action = sampleaction(samplers, i, M, balance, prior, score, bestscore, σ, ρ, κ, coldness, Dcoldness, threshold, keep, force)
+        enqueue!(queue, action, action.τ)
+    end
+
+    # todo: multiply iterations by M to keep passed iteration number indep of M?
+    iterations *= M
+
+    nterm = 0 # number of copy/kill events
+    particles = M # number of live samplers
+    t = 0.0
+    β = coldness(t)
+    pr = Progress(iterations)
+    iter = 1
+    while iter <= iterations
+        action = dequeue!(queue)
+        t = action.τ
+        β = coldness(t)
+        β > target && break
+        next!(pr; showvalues = [(:M,particles), (:t, round(t, sigdigits=6)), (:score, bestscore), (:temp, round(β, sigdigits=6))])
+
+        nterm += (action.apply! == applycopy) || (action.apply! == applykill)
+        if action.apply! == applykill
+            particles -= 1
+        end
+        if action.apply! != applyflip # flips are free
+            iter += 1
+        end
+
+        nextsample = action.apply!(samplers, action.i, action.τ, action.args...)
+        particles == 0 && break
+
+        if nextsample.alive && nextsample.scoreval > bestscore
+            bestgraph = nextsample.g
+            bestscore = nextsample.scoreval
+        end
+        action = sampleaction(samplers, action.i, M, balance, prior, score, bestscore, σ, ρ, κ, coldness, Dcoldness, threshold, keep, force)
+        enqueue!(queue, action, action.τ)
+    end
+    finish!(pr)
+    killratio = nterm/iterations
+
+    @info "multisampler finished" particles killratio t β
+
+    return bestgraph, bestscore, [sample for sample in samplers if sample.alive]
+end
diff --git a/test/multisampler.jl b/test/multisampler.jl
new file mode 100644
index 00000000..fa0999a0
--- /dev/null
+++ b/test/multisampler.jl
@@ -0,0 +1,123 @@
+using Random, CausalInference, StatsBase, Statistics, Test, Graphs, LinearAlgebra
+@testset "MultiSampler (copy/kill, keep=0.5)" begin
+    Random.seed!(1)
+
+    N = 400 # number of data points
+
+    # define simple linear model with added noise
+    x = randn(N)
+    v = x + randn(N)*0.25
+    w = x + randn(N)*0.25
+    z = v + w + randn(N)*0.25
+    s = z + randn(N)*0.25
+
+    df = (x=x, v=v, w=w, z=z, s=s)
+    iterations = 1_000
+    penalty = 2.0 # increase to get more edges in truth
+    n = length(df) # vertices
+    Random.seed!(101)
+    C = cor(CausalInference.Tables.matrix(df))
+    score = GaussianScore(C, N, penalty)
+    decay = 5.0e-5
+    schedule = (τ -> 1.0 + τ*decay, τ -> decay) # linear
+    M = 20
+    baseline = 0.0
+    bestgraph, bestscore, samplers = CausalInference.multisampler(n; M, score, baseline, schedule, iterations, keep=0.5, force=0.1)
+    #posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
+
+    # maximum aposteriori estimate
+    MAP = [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
+    @test bestgraph == digraph(MAP, n)
+    #cm = sort(countmap(vpairs.(getfield.(samplers, :g))), byvalue=true, rev=true)
+    #@test first(cm).first == MAP
+end #testset
+
+@testset "MultiSampler (sqrt balance)" begin
+    Random.seed!(1)
+
+    N = 400 # number of data points
+
+    # define simple linear model with added noise
+    x = randn(N)
+    v = x + randn(N)*0.25
+    w = x + randn(N)*0.25
+    z = v + w + randn(N)*0.25
+    s = z + randn(N)*0.25
+
+    df = (x=x, v=v, w=w, z=z, s=s)
+    iterations = 1_000
+    penalty = 2.0 # increase to get more edges in truth
+    n = length(df) # vertices
+    Random.seed!(101)
+    C = cor(CausalInference.Tables.matrix(df))
+    score = GaussianScore(C, N, penalty)
+    decay = 3e-4
+
+    schedule = (τ -> 1.0 + τ*decay, τ -> decay) # linear
+    M = 20
+    baseline = 0.0
+    balance = CausalInference.sqrt_balance
+    threshold = Inf
+    bestgraph, bestscore, samplers = multisampler(n; M, ρ = 1.0, score, balance, baseline, schedule, iterations, threshold)
+    #posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
+
+    # maximum aposteriori estimate
+    MAP = [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
+    @test bestgraph == digraph(MAP, n)
+    cm = sort(countmap(vpairs.(getfield.(samplers, :g))), byvalue=true, rev=true)
+    Tmin, T = extrema(getfield.(samplers, :τ))
+    @show Tmin T schedule[1](T)
+    @test first(cm).first == MAP
+end
+
+@testset "MultiSampler (vs. causalzigzag)" begin
+    Random.seed!(1)
+    decay = 2e-5
+    schedule = (τ -> 0.8 + τ*decay, τ -> decay) # linear
+
+    N = 200 # number of data points
+
+    # define simple linear model with added noise
+    x = randn(N)
+    v = x + randn(N)*0.25
+    w = x + randn(N)*0.25
+    z = v + w + randn(N)*0.25
+    s = z + randn(N)*0.25
+
+    df = (x=x, v=v, w=w, z=z, s=s)
+    iterations = 880
+    penalty = 2.0 # increase to get more edges in truth
+    n = length(df) # vertices
+    Random.seed!(101)
+    C = cor(CausalInference.Tables.matrix(df))
+    score = GaussianScore(C, N, penalty)
+    M = 100
+    bestgraph, bestscore, samplers = multisampler(n; M, score, schedule, iterations, target=1.2)
+    Tmin, T = extrema(getfield.(samplers, :τ))
+    coldness = schedule[1](T)
+    @show Tmin T coldness
+
+    gs = causalzigzag(n; score, κ=n-1, ρ=10.0, coldness, iterations=iterations*100)
+    graphs, graph_pairs, hs, τs, ws, ts, scores = CausalInference.unzipgs(gs)
+    posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
+
+
+    # maximum aposteriori estimate
+    MAP = [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
+    @test bestgraph == digraph(MAP, n)
+    cm = sort((proportionmap(vpairs.(getfield.(samplers, :g)))), byvalue=true, rev=true)
+    @test first(cm).first == MAP
+    logΠ = map(g->score_dag(pdag2dag!(digraph(g, n)), score), collect(keys(cm)))
+    Π = normalize(exp.(coldness*(logΠ .- maximum(logΠ))), 1)
+    Πhat = normalize(collect(values(cm)), 1)
+
+    display([Π Πhat])
+    s = 0.0
+    for (i, k) in enumerate(keys(cm))
+        s += get(posterior, k, 0.0)
+        #@show cm[k] Π[i]
+    end
+    @show s
+    @test s > 0.99
+    @test norm(collect(values(cm)) - Π) < 0.02
+end #testset
diff --git a/test/runtests.jl b/test/runtests.jl
index f7376de5..339cfa7e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -18,3 +18,4 @@ include("witness.jl")
 include("fci.jl")
 include("klentropy.jl")
 include("backdoor.jl")
+include("multisampler.jl")