Skip to content

Commit 65a41a9

Browse files
committed
add multithreading to importance sampling
1 parent a4952de commit 65a41a9

File tree

1 file changed

+51
-16
lines changed

1 file changed

+51
-16
lines changed

src/inference/importance.jl

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,40 +17,75 @@ The second variant uses a custom proposal distribution defined by the given gene
1717
All addresses of random choices sampled by the proposal should also be sampled by the model function.
1818
Setting `verbose=true` prints a progress message every sample.
1919
"""
20-
function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple,
21-
observations::ChoiceMap,
22-
num_samples::Int, verbose=false) where {T,U}
20+
function importance_sampling(
21+
model::GenerativeFunction{T,U}, model_args::Tuple,
22+
observations::ChoiceMap, num_samples::Int;
23+
verbose=false, multithreaded=false) where {T,U}
2324
traces = Vector{U}(undef, num_samples)
2425
log_weights = Vector{Float64}(undef, num_samples)
25-
for i=1:num_samples
26-
verbose && println("sample: $i of $num_samples")
27-
(traces[i], log_weights[i]) = generate(model, model_args, observations)
26+
if multithreaded
27+
Threads.@threads for i in 1:num_samples
28+
importance_sampling_iter!(traces, log_weights, model, model_args, observations, i, verbose)
29+
end
30+
else
31+
for i=1:num_samples
32+
importance_sampling_iter!(traces, log_weights, model, model_args, observations, i, verbose)
33+
end
2834
end
2935
log_total_weight = logsumexp(log_weights)
3036
log_ml_estimate = log_total_weight - log(num_samples)
3137
log_normalized_weights = log_weights .- log_total_weight
3238
return (traces, log_normalized_weights, log_ml_estimate)
3339
end
3440

35-
function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple,
36-
observations::ChoiceMap,
37-
proposal::GenerativeFunction, proposal_args::Tuple,
38-
num_samples::Int, verbose=false) where {T,U}
41+
function importance_sampling_iter!(
42+
traces::Vector, log_weights::Vector{Float64},
43+
model::GenerativeFunction, model_args::Tuple,
44+
observations::ChoiceMap, i::Int, verbose::Bool)
45+
(traces[i], log_weights[i]) = generate(model, model_args, observations)
46+
verbose && Core.println("sample: $i of $num_samples completed in thread $(Threads.threadid())")
47+
return nothing
48+
end
49+
50+
function importance_sampling(
51+
model::GenerativeFunction{T,U}, model_args::Tuple,
52+
observations::ChoiceMap, proposal::GenerativeFunction, proposal_args::Tuple,
53+
num_samples::Int; verbose=false, multithreaded=false) where {T,U}
3954
traces = Vector{U}(undef, num_samples)
4055
log_weights = Vector{Float64}(undef, num_samples)
41-
for i=1:num_samples
42-
verbose && println("sample: $i of $num_samples")
43-
(proposed_choices, proposal_weight, _) = propose(proposal, proposal_args)
44-
constraints = merge(observations, proposed_choices)
45-
(traces[i], model_weight) = generate(model, model_args, constraints)
46-
log_weights[i] = model_weight - proposal_weight
56+
if multithreaded
57+
Threads.@threads for i=1:num_samples
58+
importance_sampling_iter!(
59+
traces, log_weights, model, model_args,
60+
observations, proposal, proposal_args, i, verbose)
61+
end
62+
else
63+
for i=1:num_samples
64+
importance_sampling_iter!(
65+
traces, log_weights, model, model_args,
66+
observations, proposal, proposal_args, i, verbose)
67+
end
4768
end
4869
log_total_weight = logsumexp(log_weights)
4970
log_ml_estimate = log_total_weight - log(num_samples)
5071
log_normalized_weights = log_weights .- log_total_weight
5172
return (traces, log_normalized_weights, log_ml_estimate)
5273
end
5374

75+
function importance_sampling_iter!(
76+
traces::Vector, log_weights::Vector{Float64},
77+
model::GenerativeFunction, model_args::Tuple,
78+
observations::ChoiceMap,
79+
proposal::GenerativeFunction, proposal_args::Tuple,
80+
i::Int, verbose::Bool)
81+
(proposed_choices, proposal_weight, _) = propose(proposal, proposal_args)
82+
constraints = merge(observations, proposed_choices)
83+
(traces[i], model_weight) = generate(model, model_args, constraints)
84+
log_weights[i] = model_weight - proposal_weight
85+
verbose && Core.println("sample $i of $num_samples completed in thread $(Threads.threadid())")
86+
return nothing
87+
end
88+
5489
"""
5590
(trace, lml_est) = importance_resampling(model::GenerativeFunction,
5691
model_args::Tuple, observations::ChoiceMap, num_samples::Int,

0 commit comments

Comments
 (0)