11"""
2- BayesianNetwork
2+ BayesianNetwork
33
44A structure representing a Bayesian Network.
55"""
@@ -49,7 +49,7 @@ function BayesianNetwork{V}() where {V}
4949end
5050
5151"""
52- translate_BUGSGraph_to_BayesianNetwork(g::MetaGraph; init=Dict{Symbol,Any}())
52+ translate_BUGSGraph_to_BayesianNetwork(g::MetaGraph; init=Dict{Symbol,Any}())
5353
5454Translates a BUGSGraph (with node metadata stored in NodeInfo) into a BayesianNetwork.
5555"""
@@ -77,9 +77,7 @@ function translate_BUGSGraph_to_BayesianNetwork(
7777
7878 if model != = nothing
7979 if isdefined (model, :transformed_var_lengths )
80- for (k, v) in pairs (model. transformed_var_lengths)
81- transformed_var_lengths[k] = v
82- end
80+ transformed_var_lengths = copy (model. transformed_var_lengths)
8381 end
8482 if isdefined (model, :transformed_param_length )
8583 transformed_param_length = model. transformed_param_length
@@ -135,7 +133,7 @@ function translate_BUGSGraph_to_BayesianNetwork(
135133end
136134
137135"""
138- add_stochastic_vertex!(bn::BayesianNetwork{V,T}, name::V, dist::Any, node_type::Symbol; is_observed::Bool=false) where {V,T}
136+ add_stochastic_vertex!(bn::BayesianNetwork{V,T}, name::V, dist::Any, node_type::Symbol; is_observed::Bool=false) where {V,T}
139137
140138Add a stochastic vertex with name `name`, a distribution object/function `dist`,
141139and a declared node_type (`:discrete` or `:continuous`).
@@ -160,7 +158,7 @@ function add_stochastic_vertex!(
160158end
161159
162160"""
163- add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F) where {T,V,F}
161+ add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F) where {T,V,F}
164162
165163Add a deterministic vertex.
166164"""
@@ -178,7 +176,7 @@ function add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F)::T w
178176end
179177
180178"""
181- add_edge!(bn::BayesianNetwork{V,T}, from::V, to::V) where {T,V}
179+ add_edge!(bn::BayesianNetwork{V,T}, from::V, to::V) where {T,V}
182180
183181Add a directed edge from `from` -> `to`.
184182"""
@@ -261,3 +259,177 @@ function evaluate_with_values(bn::BayesianNetwork, parameter_values::AbstractVec
261259
262260 return evaluation_env, logprior + loglikelihood
263261end
262+
263+ function evaluate_with_marginalization (
264+ bn:: BayesianNetwork{V,T,F} , parameter_values:: AbstractVector
265+ ) where {V,T,F}
266+ # Get topological ordering of nodes
267+ sorted_node_ids = topological_sort_by_dfs (bn. graph)
268+
269+ # Find continuous variables (all stochastic unobserved variables that are not discrete)
270+ continuous_vars = [
271+ bn. names[i] for i in sorted_node_ids if
272+ bn. is_stochastic[i] && ! bn. is_observed[i] && bn. node_types[i] != :discrete
273+ ]
274+
275+ # Calculate total parameter length needed
276+ total_param_length = 0
277+ for name in continuous_vars
278+ if haskey (bn. transformed_var_lengths, name)
279+ total_param_length += bn. transformed_var_lengths[name]
280+ end
281+ end
282+
283+ # No discrete variables case - use standard evaluation
284+ discrete_vars = [
285+ bn. names[i] for i in sorted_node_ids if
286+ bn. is_stochastic[i] && ! bn. is_observed[i] && bn. node_types[i] == :discrete
287+ ]
288+
289+ if isempty (discrete_vars)
290+ return evaluate_with_values (bn, parameter_values)
291+ end
292+
293+ # Initialize environment once
294+ env = deepcopy (bn. evaluation_env)
295+
296+ # Start recursive evaluation with the first node, beginning at parameter index 1
297+ logp = _marginalize_recursive (
298+ bn, env, sorted_node_ids, parameter_values, 1 , bn. transformed_var_lengths
299+ )
300+
301+ return env, logp
302+ end
303+
304+ function _marginalize_recursive (
305+ bn:: BayesianNetwork{V,T,F} ,
306+ env,
307+ remaining_nodes,
308+ parameter_values:: AbstractVector ,
309+ param_idx:: Int ,
310+ var_lengths,
311+ ) where {V,T,F}
312+ # Base case: no more nodes to process
313+ if isempty (remaining_nodes)
314+ return 0.0
315+ end
316+
317+ # Process current node
318+ current_id = remaining_nodes[1 ]
319+ current_name = bn. names[current_id]
320+
321+ # Check node type
322+ is_stochastic = bn. is_stochastic[current_id]
323+ is_observed = bn. is_observed[current_id]
324+ is_discrete = bn. node_types[current_id] == :discrete
325+
326+ if ! is_stochastic
327+ # Deterministic node - compute value and continue
328+ value = bn. deterministic_functions[current_id](env, bn. loop_vars[current_name])
329+ env = BangBang. setindex!! (env, value, current_name)
330+ return _marginalize_recursive (
331+ bn, env, @view (remaining_nodes[2 : end ]), parameter_values, param_idx, var_lengths
332+ )
333+
334+ elseif is_observed
335+ # Observed node - add log probability and continue
336+ dist = bn. distributions[current_id](env, bn. loop_vars[current_name])
337+ obs_logp = logpdf (dist, AbstractPPL. get (env, current_name))
338+ remaining_logp = _marginalize_recursive (
339+ bn, env, @view (remaining_nodes[2 : end ]), parameter_values, param_idx, var_lengths
340+ )
341+ return obs_logp + remaining_logp
342+
343+ elseif is_discrete
344+ # Discrete unobserved node - marginalize over possible values
345+ dist = bn. distributions[current_id](env, bn. loop_vars[current_name])
346+ possible_values = enumerate_discrete_values (dist)
347+
348+ # Collect log probabilities for all possible values
349+ logp_branches = Vector {Float64} (undef, length (possible_values))
350+
351+ for (i, value) in enumerate (possible_values)
352+ # Create a branch-specific environment
353+ branch_env = BangBang. setindex!! (deepcopy (env), value, current_name)
354+
355+ # Compute log probability of this value
356+ value_logp = logpdf (dist, value)
357+
358+ # Continue evaluation with this assignment
359+ # Important: We use the same param_idx for all branches since discrete variables
360+ # don't consume parameters
361+ remaining_logp = _marginalize_recursive (
362+ bn,
363+ branch_env,
364+ @view (remaining_nodes[2 : end ]),
365+ parameter_values,
366+ param_idx,
367+ var_lengths,
368+ )
369+
370+ logp_branches[i] = value_logp + remaining_logp
371+ end
372+
373+ # Marginalize using logsumexp for numerical stability
374+ return LogExpFunctions. logsumexp (logp_branches)
375+
376+ else
377+ # Continuous unobserved node - use parameter values
378+ dist = bn. distributions[current_id](env, bn. loop_vars[current_name])
379+ b = Bijectors. bijector (dist)
380+
381+ # Ensure variable length is in the dictionary
382+ if ! haskey (var_lengths, current_name)
383+ error (
384+ " Missing transformed length for variable '$(current_name) '. All variables should have their transformed lengths pre-computed in JuliaBUGS." ,
385+ )
386+ end
387+
388+ l = var_lengths[current_name]
389+
390+ # Process the continuous variable
391+ b_inv = Bijectors. inverse (b)
392+ param_slice = view (parameter_values, param_idx: (param_idx + l - 1 ))
393+ reconstructed_value = JuliaBUGS. reconstruct (b_inv, dist, param_slice)
394+ value, logjac = Bijectors. with_logabsdet_jacobian (b_inv, reconstructed_value)
395+
396+ # Update environment
397+ env = BangBang. setindex!! (env, value, current_name)
398+
399+ # Compute log probability and continue with updated parameter index
400+ dist_logp = logpdf (dist, value) + logjac
401+ next_idx = param_idx + l
402+ remaining_logp = _marginalize_recursive (
403+ bn, env, @view (remaining_nodes[2 : end ]), parameter_values, next_idx, var_lengths
404+ )
405+
406+ return dist_logp + remaining_logp
407+ end
408+ end
409+
410+ """
411+ enumerate_discrete_values(dist)
412+
413+ Return all possible values for a discrete distribution.
414+ Currently supports Categorical, Bernoulli, Binomial, and DiscreteUniform distributions.
415+ """
416+ function enumerate_discrete_values (dist:: DiscreteUnivariateDistribution )
417+ if dist isa Categorical
418+ return 1 : length (dist. p)
419+ elseif dist isa Bernoulli
420+ return [0 , 1 ]
421+ elseif dist isa Binomial
422+ # Handle special case where n is 0
423+ if dist. n == 0
424+ return 0 : 0
425+ else
426+ return 0 : (dist. n)
427+ end
428+ elseif dist isa DiscreteUniform
429+ return (dist. a): (dist. b)
430+ else
431+ error (
432+ " Distribution type $(typeof (dist)) is not currently supported for discrete marginalization" ,
433+ )
434+ end
435+ end
0 commit comments