Skip to content

Commit 5b7fbdb

Browse files
naseweisssssgithub-actions[bot]sunxd3
authored
Evaluate Function with Recursive Approach to Discreteness (#289)
- Marginalization implementation (recursion-based) for discrete variables in Bayesian networks - Test case included - There should be another PR with similar implementation but with DP. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Xianda Sun <[email protected]>
1 parent cf3112d commit 5b7fbdb

File tree

3 files changed

+766
-9
lines changed

3 files changed

+766
-9
lines changed

src/experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using JuliaBUGS: BUGSGraph, VarName, NodeInfo
99
using AbstractPPL
1010
using Bijectors: Bijectors
1111
using LinearAlgebra: Cholesky
12+
using LogExpFunctions
1213

1314
include("bayesian_network.jl")
1415
include("conditioning.jl")

src/experimental/ProbabilisticGraphicalModels/bayesian_network.jl

Lines changed: 180 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
BayesianNetwork
2+
BayesianNetwork
33
44
A structure representing a Bayesian Network.
55
"""
@@ -49,7 +49,7 @@ function BayesianNetwork{V}() where {V}
4949
end
5050

5151
"""
52-
translate_BUGSGraph_to_BayesianNetwork(g::MetaGraph; init=Dict{Symbol,Any}())
52+
translate_BUGSGraph_to_BayesianNetwork(g::MetaGraph; init=Dict{Symbol,Any}())
5353
5454
Translates a BUGSGraph (with node metadata stored in NodeInfo) into a BayesianNetwork.
5555
"""
@@ -77,9 +77,7 @@ function translate_BUGSGraph_to_BayesianNetwork(
7777

7878
if model !== nothing
7979
if isdefined(model, :transformed_var_lengths)
80-
for (k, v) in pairs(model.transformed_var_lengths)
81-
transformed_var_lengths[k] = v
82-
end
80+
transformed_var_lengths = copy(model.transformed_var_lengths)
8381
end
8482
if isdefined(model, :transformed_param_length)
8583
transformed_param_length = model.transformed_param_length
@@ -135,7 +133,7 @@ function translate_BUGSGraph_to_BayesianNetwork(
135133
end
136134

137135
"""
138-
add_stochastic_vertex!(bn::BayesianNetwork{V,T}, name::V, dist::Any, node_type::Symbol; is_observed::Bool=false) where {V,T}
136+
add_stochastic_vertex!(bn::BayesianNetwork{V,T}, name::V, dist::Any, node_type::Symbol; is_observed::Bool=false) where {V,T}
139137
140138
Add a stochastic vertex with name `name`, a distribution object/function `dist`,
141139
and a declared node_type (`:discrete` or `:continuous`).
@@ -160,7 +158,7 @@ function add_stochastic_vertex!(
160158
end
161159

162160
"""
163-
add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F) where {T,V,F}
161+
add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F) where {T,V,F}
164162
165163
Add a deterministic vertex.
166164
"""
@@ -178,7 +176,7 @@ function add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F)::T w
178176
end
179177

180178
"""
181-
add_edge!(bn::BayesianNetwork{V,T}, from::V, to::V) where {T,V}
179+
add_edge!(bn::BayesianNetwork{V,T}, from::V, to::V) where {T,V}
182180
183181
Add a directed edge from `from` -> `to`.
184182
"""
@@ -261,3 +259,177 @@ function evaluate_with_values(bn::BayesianNetwork, parameter_values::AbstractVec
261259

262260
return evaluation_env, logprior + loglikelihood
263261
end
262+
263+
function evaluate_with_marginalization(
264+
bn::BayesianNetwork{V,T,F}, parameter_values::AbstractVector
265+
) where {V,T,F}
266+
# Get topological ordering of nodes
267+
sorted_node_ids = topological_sort_by_dfs(bn.graph)
268+
269+
# Find continuous variables (all stochastic unobserved variables that are not discrete)
270+
continuous_vars = [
271+
bn.names[i] for i in sorted_node_ids if
272+
bn.is_stochastic[i] && !bn.is_observed[i] && bn.node_types[i] != :discrete
273+
]
274+
275+
# Calculate total parameter length needed
276+
total_param_length = 0
277+
for name in continuous_vars
278+
if haskey(bn.transformed_var_lengths, name)
279+
total_param_length += bn.transformed_var_lengths[name]
280+
end
281+
end
282+
283+
# No discrete variables case - use standard evaluation
284+
discrete_vars = [
285+
bn.names[i] for i in sorted_node_ids if
286+
bn.is_stochastic[i] && !bn.is_observed[i] && bn.node_types[i] == :discrete
287+
]
288+
289+
if isempty(discrete_vars)
290+
return evaluate_with_values(bn, parameter_values)
291+
end
292+
293+
# Initialize environment once
294+
env = deepcopy(bn.evaluation_env)
295+
296+
# Start recursive evaluation with the first node, beginning at parameter index 1
297+
logp = _marginalize_recursive(
298+
bn, env, sorted_node_ids, parameter_values, 1, bn.transformed_var_lengths
299+
)
300+
301+
return env, logp
302+
end
303+
304+
function _marginalize_recursive(
305+
bn::BayesianNetwork{V,T,F},
306+
env,
307+
remaining_nodes,
308+
parameter_values::AbstractVector,
309+
param_idx::Int,
310+
var_lengths,
311+
) where {V,T,F}
312+
# Base case: no more nodes to process
313+
if isempty(remaining_nodes)
314+
return 0.0
315+
end
316+
317+
# Process current node
318+
current_id = remaining_nodes[1]
319+
current_name = bn.names[current_id]
320+
321+
# Check node type
322+
is_stochastic = bn.is_stochastic[current_id]
323+
is_observed = bn.is_observed[current_id]
324+
is_discrete = bn.node_types[current_id] == :discrete
325+
326+
if !is_stochastic
327+
# Deterministic node - compute value and continue
328+
value = bn.deterministic_functions[current_id](env, bn.loop_vars[current_name])
329+
env = BangBang.setindex!!(env, value, current_name)
330+
return _marginalize_recursive(
331+
bn, env, @view(remaining_nodes[2:end]), parameter_values, param_idx, var_lengths
332+
)
333+
334+
elseif is_observed
335+
# Observed node - add log probability and continue
336+
dist = bn.distributions[current_id](env, bn.loop_vars[current_name])
337+
obs_logp = logpdf(dist, AbstractPPL.get(env, current_name))
338+
remaining_logp = _marginalize_recursive(
339+
bn, env, @view(remaining_nodes[2:end]), parameter_values, param_idx, var_lengths
340+
)
341+
return obs_logp + remaining_logp
342+
343+
elseif is_discrete
344+
# Discrete unobserved node - marginalize over possible values
345+
dist = bn.distributions[current_id](env, bn.loop_vars[current_name])
346+
possible_values = enumerate_discrete_values(dist)
347+
348+
# Collect log probabilities for all possible values
349+
logp_branches = Vector{Float64}(undef, length(possible_values))
350+
351+
for (i, value) in enumerate(possible_values)
352+
# Create a branch-specific environment
353+
branch_env = BangBang.setindex!!(deepcopy(env), value, current_name)
354+
355+
# Compute log probability of this value
356+
value_logp = logpdf(dist, value)
357+
358+
# Continue evaluation with this assignment
359+
# Important: We use the same param_idx for all branches since discrete variables
360+
# don't consume parameters
361+
remaining_logp = _marginalize_recursive(
362+
bn,
363+
branch_env,
364+
@view(remaining_nodes[2:end]),
365+
parameter_values,
366+
param_idx,
367+
var_lengths,
368+
)
369+
370+
logp_branches[i] = value_logp + remaining_logp
371+
end
372+
373+
# Marginalize using logsumexp for numerical stability
374+
return LogExpFunctions.logsumexp(logp_branches)
375+
376+
else
377+
# Continuous unobserved node - use parameter values
378+
dist = bn.distributions[current_id](env, bn.loop_vars[current_name])
379+
b = Bijectors.bijector(dist)
380+
381+
# Ensure variable length is in the dictionary
382+
if !haskey(var_lengths, current_name)
383+
error(
384+
"Missing transformed length for variable '$(current_name)'. All variables should have their transformed lengths pre-computed in JuliaBUGS.",
385+
)
386+
end
387+
388+
l = var_lengths[current_name]
389+
390+
# Process the continuous variable
391+
b_inv = Bijectors.inverse(b)
392+
param_slice = view(parameter_values, param_idx:(param_idx + l - 1))
393+
reconstructed_value = JuliaBUGS.reconstruct(b_inv, dist, param_slice)
394+
value, logjac = Bijectors.with_logabsdet_jacobian(b_inv, reconstructed_value)
395+
396+
# Update environment
397+
env = BangBang.setindex!!(env, value, current_name)
398+
399+
# Compute log probability and continue with updated parameter index
400+
dist_logp = logpdf(dist, value) + logjac
401+
next_idx = param_idx + l
402+
remaining_logp = _marginalize_recursive(
403+
bn, env, @view(remaining_nodes[2:end]), parameter_values, next_idx, var_lengths
404+
)
405+
406+
return dist_logp + remaining_logp
407+
end
408+
end
409+
410+
"""
411+
enumerate_discrete_values(dist)
412+
413+
Return all possible values for a discrete distribution.
414+
Currently supports Categorical, Bernoulli, Binomial, and DiscreteUniform distributions.
415+
"""
416+
function enumerate_discrete_values(dist::DiscreteUnivariateDistribution)
417+
if dist isa Categorical
418+
return 1:length(dist.p)
419+
elseif dist isa Bernoulli
420+
return [0, 1]
421+
elseif dist isa Binomial
422+
# Handle special case where n is 0
423+
if dist.n == 0
424+
return 0:0
425+
else
426+
return 0:(dist.n)
427+
end
428+
elseif dist isa DiscreteUniform
429+
return (dist.a):(dist.b)
430+
else
431+
error(
432+
"Distribution type $(typeof(dist)) is not currently supported for discrete marginalization",
433+
)
434+
end
435+
end

0 commit comments

Comments
 (0)