Skip to content
Open
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
38e4efe
add auto-marginalization implementation to the main package
sunxd3 Aug 23, 2025
a5aeb6b
format
sunxd3 Aug 23, 2025
b4eb8d5
add import
sunxd3 Aug 23, 2025
fbfac76
import Exponential distribution
sunxd3 Aug 23, 2025
3c0504e
using LogExpFunctions in tests
sunxd3 Aug 23, 2025
6b1206e
test on GMM
sunxd3 Aug 23, 2025
7243c23
fix some errors
sunxd3 Aug 26, 2025
2bc9e4b
add hmm sanity check
sunxd3 Aug 28, 2025
ad90ec8
format some example files
sunxd3 Aug 28, 2025
6de0807
fix error
sunxd3 Aug 29, 2025
0445a28
add sampling tests
sunxd3 Aug 29, 2025
f06e2f1
formatting
sunxd3 Aug 29, 2025
184434c
Update JuliaBUGS/test/model/auto_marginalization_sampling.jl
sunxd3 Aug 29, 2025
bc107b2
Update JuliaBUGS/test/model/auto_marginalization.jl
sunxd3 Aug 29, 2025
92a907b
Update JuliaBUGS/test/model/auto_marginalization.jl
sunxd3 Aug 29, 2025
434afd1
Merge branch 'main' into sunxd/auto_marginalization
sunxd3 Aug 29, 2025
e812006
fix test error
sunxd3 Aug 29, 2025
618c23e
stop using `invokelatest`
sunxd3 Sep 1, 2025
7eb5224
fix test errors
sunxd3 Sep 1, 2025
f441205
Merge branch 'main' into sunxd/auto_marginalization
sunxd3 Sep 1, 2025
7b8cedd
fix performance by moving more computation to the construction
sunxd3 Sep 1, 2025
2d66d1c
Update bugsmodel.jl
sunxd3 Sep 2, 2025
5f00a14
add experiment package
sunxd3 Sep 14, 2025
4ea7726
example
sunxd3 Sep 23, 2025
c52e5e0
remove experiments
sunxd3 Sep 23, 2025
7a38695
chore: empty commit [skip ci]
sunxd3 Sep 23, 2025
1d229ca
chore: revert example formatting changes
sunxd3 Sep 23, 2025
f7a044f
move the auto-marg doc into JuliaBUGS docs folder
sunxd3 Sep 23, 2025
6d2bc02
rename the auto marg doc
sunxd3 Sep 23, 2025
f6002b3
Merge branch 'main' into sunxd/auto_marginalization
sunxd3 Sep 23, 2025
d8874e4
fix: stabilize auto-marginalization caches and tempering
sunxd3 Sep 23, 2025
cf180e3
add experiment code
sunxd3 Sep 24, 2025
f62ff91
check in all the experiment code
sunxd3 Sep 30, 2025
2d7a51b
update scripts; remove results
sunxd3 Sep 30, 2025
aec0159
update plan
sunxd3 Sep 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions JuliaBUGS/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
style="blue"
always_use_return=false
short_to_long_function_def=false
2 changes: 2 additions & 0 deletions JuliaBUGS/src/BUGSExamples/Volume_2/08_Cervix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ model
}
"""

#! format: off
data = (N = 2044,
d = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
Expand Down Expand Up @@ -442,6 +443,7 @@ data = (N = 2044,

inits = (beta0C = 0, beta = 0, q = 0.5, phi = [0.5 0.5; 0.5 0.5])
inits_alternative = (beta0C = 1.0, beta = 1.0, q = 0.75, phi = [0.15 0.15; 0.15 0.15])
#! format: on

reference_results = (
var"beta" = (mean = 0.6156, std = 0.3406),
Expand Down
2 changes: 2 additions & 0 deletions JuliaBUGS/src/BUGSExamples/Volume_2/11_Schools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ model
}
"""

#! format: off
data = (N = 1978, M = 38, mn = [0, 0, 0],
prec = [0.0001 0 0; 0 0.0001 0; 0 0 0.0001],
R = [0.1 0.005 0.005; 0.005 0.1 0.005; 0.005 0.005 0.1],
Expand Down Expand Up @@ -1573,6 +1574,7 @@ inits_alternative = (theta = 0.1, phi = 0, gamma = [1.0, 1.0, 1.0],
0 0 0
0 0 0
0 0 0])
#! format: on

reference_results = (
var"beta[1]" = (mean = 2.589E-4, std = 9.8E-5),
Expand Down
2 changes: 1 addition & 1 deletion JuliaBUGS/src/BUGSExamples/Volume_2/16_Stagnant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ reference_results = (
)

stagnant = Example(
name, model_def, original, data, inits, inits_alternative, reference_results)
name, model_def, original, data, inits, inits_alternative, reference_results)
1 change: 1 addition & 0 deletions JuliaBUGS/src/BUGSExamples/Volume_2/18_Pigs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

1 change: 1 addition & 0 deletions JuliaBUGS/src/BUGSExamples/Volume_2/19_Simulating_data.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

2 changes: 1 addition & 1 deletion JuliaBUGS/src/BUGSExamples/Volume_3/09_Hips4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ model_def = @bugs begin
for t in 2:N
for s in 1:S
var"pi"[
n, k, t, s] = inprod(
n, k, t, s] = inprod(
var"pi"[n, k, t - 1, :], Lambda[n, k, t, :, s])
end
end
Expand Down
3 changes: 2 additions & 1 deletion JuliaBUGS/src/BUGSExamples/Volume_4/11_Methadone.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ end
data_dict = JSON.parsefile(joinpath(readdir(), "methadone_data.json"))
data = NamedTuple{Tuple([Symbol(key) for key in keys(data_dict)])}(Tuple([map(identity,
val)
for val in values(data_dict)]))
for val in
values(data_dict)]))

inits = (
lambda = 0,
Expand Down
3 changes: 2 additions & 1 deletion JuliaBUGS/src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ using .Model:
BUGSModel,
evaluate_with_values!!,
UseGraph,
UseGeneratedLogDensityFunction
UseGeneratedLogDensityFunction,
UseAutoMarginalization

include("independent_mh.jl")
include("gibbs.jl")
Expand Down
4 changes: 4 additions & 0 deletions JuliaBUGS/src/model/Model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Graphs
using LinearAlgebra
using JuliaBUGS: JuliaBUGS, BUGSGraph
using JuliaBUGS.BUGSPrimitives
using LogExpFunctions
using MetaGraphsNext
using Random

Expand All @@ -20,5 +21,8 @@ include("logdensityproblems.jl")

export parameters, variables, initialize!, getparams, settrans, set_evaluation_mode
export evaluate_with_rng!!, evaluate_with_env!!, evaluate_with_values!!
export evaluate_with_marginalization_rng!!,
evaluate_with_marginalization_env!!, evaluate_with_marginalization_values!!
export UseAutoMarginalization, enumerate_discrete_values

end # Model
42 changes: 37 additions & 5 deletions JuliaBUGS/src/model/abstractppl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ function _create_modified_model(
# Recompute mutable symbols for the new graph
new_mutable_symbols = get_mutable_symbols(updated_graph_evaluation_data)

# Create the new model with all updated fields
# Create the new model with all updated fields (without auto-marg caches yet)
kwargs = Dict{Symbol,Any}(
:untransformed_param_length => new_untransformed_param_length,
:transformed_param_length => new_transformed_param_length,
Expand All @@ -585,7 +585,33 @@ function _create_modified_model(
kwargs[:base_model] = base_model
end

return BUGSModel(model; kwargs...)
new_model = BUGSModel(model; kwargs...)

# Compute and attach auto-marg caches once for the new graph
try
order = JuliaBUGS.Model._compute_marginalization_order(new_model)
keys = JuliaBUGS.Model._precompute_minimal_cache_keys(new_model, order)

gd = new_model.graph_evaluation_data
gd_cached = GraphEvaluationData{
typeof(gd.node_function_vals),typeof(gd.loop_vars_vals)
}(
gd.sorted_nodes,
gd.sorted_parameters,
gd.is_stochastic_vals,
gd.is_observed_vals,
gd.node_function_vals,
gd.loop_vars_vals,
gd.node_types,
gd.is_discrete_finite_vals,
keys,
order,
)
return BUGSModel(new_model; graph_evaluation_data=gd_cached)
catch
# If caches cannot be computed (e.g., unsupported model), return model as-is.
return new_model
end
end

# Common helper function to regenerate log density function
Expand Down Expand Up @@ -698,8 +724,14 @@ function evaluate!!(
temperature=1.0,
transformed=model.transformed,
)
evaluation_env, log_densities = evaluate_with_values!!(
model, flattened_values; temperature=temperature, transformed=transformed
)
if model.evaluation_mode isa UseAutoMarginalization
evaluation_env, log_densities = evaluate_with_marginalization_values!!(
model, flattened_values; temperature=temperature, transformed=transformed
)
else
evaluation_env, log_densities = evaluate_with_values!!(
model, flattened_values; temperature=temperature, transformed=transformed
)
end
return evaluation_env, log_densities.tempered_logjoint
end
Loading