@@ -17,40 +17,75 @@ The second variant uses a custom proposal distribution defined by the given gene
1717All addresses of random choices sampled by the proposal should also be sampled by the model function.
1818Setting `verbose=true` prints a progress message every sample.
1919"""
20- function importance_sampling (model:: GenerativeFunction{T,U} , model_args:: Tuple ,
21- observations:: ChoiceMap ,
22- num_samples:: Int , verbose= false ) where {T,U}
20+ function importance_sampling (
21+ model:: GenerativeFunction{T,U} , model_args:: Tuple ,
22+ observations:: ChoiceMap , num_samples:: Int ;
23+ verbose= false , multithreaded= false ) where {T,U}
2324 traces = Vector {U} (undef, num_samples)
2425 log_weights = Vector {Float64} (undef, num_samples)
25- for i= 1 : num_samples
26- verbose && println (" sample: $i of $num_samples " )
27- (traces[i], log_weights[i]) = generate (model, model_args, observations)
26+ if multithreaded
27+ Threads. @threads for i in 1 : num_samples
28+ importance_sampling_iter! (traces, log_weights, model, model_args, observations, i, verbose)
29+ end
30+ else
31+ for i= 1 : num_samples
32+ importance_sampling_iter! (traces, log_weights, model, model_args, observations, i, verbose)
33+ end
2834 end
2935 log_total_weight = logsumexp (log_weights)
3036 log_ml_estimate = log_total_weight - log (num_samples)
3137 log_normalized_weights = log_weights .- log_total_weight
3238 return (traces, log_normalized_weights, log_ml_estimate)
3339end
3440
35- function importance_sampling (model:: GenerativeFunction{T,U} , model_args:: Tuple ,
36- observations:: ChoiceMap ,
37- proposal:: GenerativeFunction , proposal_args:: Tuple ,
38- num_samples:: Int , verbose= false ) where {T,U}
41+ function importance_sampling_iter! (
42+ traces:: Vector , log_weights:: Vector{Float64} ,
43+ model:: GenerativeFunction , model_args:: Tuple ,
44+ observations:: ChoiceMap , i:: Int , verbose:: Bool )
45+ (traces[i], log_weights[i]) = generate (model, model_args, observations)
46+ verbose && Core. println (" sample: $i of $num_samples completed in thread $(Threads. threadid ()) " )
47+ return nothing
48+ end
49+
50+ function importance_sampling (
51+ model:: GenerativeFunction{T,U} , model_args:: Tuple ,
52+ observations:: ChoiceMap , proposal:: GenerativeFunction , proposal_args:: Tuple ,
53+ num_samples:: Int ; verbose= false , multithreaded= false ) where {T,U}
3954 traces = Vector {U} (undef, num_samples)
4055 log_weights = Vector {Float64} (undef, num_samples)
41- for i= 1 : num_samples
42- verbose && println (" sample: $i of $num_samples " )
43- (proposed_choices, proposal_weight, _) = propose (proposal, proposal_args)
44- constraints = merge (observations, proposed_choices)
45- (traces[i], model_weight) = generate (model, model_args, constraints)
46- log_weights[i] = model_weight - proposal_weight
56+ if multithreaded
57+ Threads. @threads for i= 1 : num_samples
58+ importance_sampling_iter! (
59+ traces, log_weights, model, model_args,
60+ observations, proposal, proposal_args, i, verbose)
61+ end
62+ else
63+ for i= 1 : num_samples
64+ importance_sampling_iter! (
65+ traces, log_weights, model, model_args,
66+ observations, proposal, proposal_args, i, verbose)
67+ end
4768 end
4869 log_total_weight = logsumexp (log_weights)
4970 log_ml_estimate = log_total_weight - log (num_samples)
5071 log_normalized_weights = log_weights .- log_total_weight
5172 return (traces, log_normalized_weights, log_ml_estimate)
5273end
5374
75+ function importance_sampling_iter! (
76+ traces:: Vector , log_weights:: Vector{Float64} ,
77+ model:: GenerativeFunction , model_args:: Tuple ,
78+ observations:: ChoiceMap ,
79+ proposal:: GenerativeFunction , proposal_args:: Tuple ,
80+ i:: Int , verbose:: Bool )
81+ (proposed_choices, proposal_weight, _) = propose (proposal, proposal_args)
82+ constraints = merge (observations, proposed_choices)
83+ (traces[i], model_weight) = generate (model, model_args, constraints)
84+ log_weights[i] = model_weight - proposal_weight
85+ verbose && Core. println (" sample $i of $num_samples completed in thread $(Threads. threadid ()) " )
86+ return nothing
87+ end
88+
5489"""
5590 (trace, lml_est) = importance_resampling(model::GenerativeFunction,
5691 model_args::Tuple, observations::ChoiceMap, num_samples::Int,
0 commit comments