@@ -7,8 +7,8 @@ module AdaptiveRejectionSampling
77# ------------------------------
88using Random # Random stdlib
99# ------------------------------
10- using ForwardDiff # For automatic differentiation, no user nor approximate derivatives
11- using StatsBase # To include the basic sample from array function
10+ import ForwardDiff: derivative
11+ import StatsBase: sample, weights
1212# ------------------------------
1313export Line, Objective, Envelop, RejectionSampler # Structures/classes
1414export run_sampler!, sample_envelop, eval_envelop, add_segment! # Methods
@@ -29,7 +29,7 @@ Finds the horizontal coordinate of the intersection between lines
2929"""
3030function intersection (l1:: Line , l2:: Line )
3131 @assert l1. slope != l2. slope " slopes should be different"
32- - (l2. intercept - l1. intercept) / (l2. slope - l1. slope)
32+ - (l2. intercept - l1. intercept) / (l2. slope - l1. slope)
3333end
3434
3535"""
@@ -38,9 +38,19 @@ Computes the integral
3838 ``LaTeX \\ int_{x_1} ^ {x_2} \\ exp\\ {ax + b\\ } dx. ``
3939The resulting value is the weight assigned to the segment [x1, x2] in the envelop
4040"""
41- function exp_integral (l:: Line , x1:: Float64 , x2:: Float64 )
41+ @inline function exp_integral (l:: Line , x1:: Float64 , x2:: Float64 )
4242 a, b = l. slope, l. intercept
43- exp (b) * (exp (a * x2) - exp (a * x1)) / a
43+ v1, v2 = a* x1, a* x2
44+ if v1 > 25.0 || v2 > 25.0 || a == 0.0 || b > 25.0
45+ @warn " exp_integral: numerical instability, truncating, check for under/overflow, consider truncating logf"
46+ v1 = min (v1, 25.0 )
47+ v2 = min (v2, 25.0 )
48+ b = min (b, 25.0 )
49+ if a == 0
50+ a = (v2 - v1) * 1e-6
51+ end
52+ end
53+ exp (b) * (exp (v2) - exp (v1)) / a
4454end
4555
4656"""
@@ -56,13 +66,13 @@ mutable struct Envelop
5666 weights:: AbstractVector{Float64}
5767 size:: Int
5868
59- Envelop (lines:: Vector{Line} , support:: Tuple{Float64, Float64} ) = begin
60- @assert issorted ([l. slope for l in lines], rev = true ) " line slopes must be decreasing"
61- intersections = [intersection (lines[i], lines[i + 1 ]) for i in 1 : (length (lines) - 1 )]
69+ Envelop (lines:: Vector{Line} , support:: Tuple{Float64,Float64} ) = begin
70+ @assert issorted ([l. slope for l in lines], rev= true ) " line slopes must be decreasing"
71+ intersections = [intersection (lines[i], lines[i+ 1 ]) for i in 1 : (length (lines)- 1 )]
6272 cutpoints = [support[1 ]; intersections; support[2 ]]
6373 @assert issorted (cutpoints) " cutpoints must be ordered"
6474 @assert length (unique (cutpoints)) == length (cutpoints) " cutpoints can't have duplicates"
65- weights = [exp_integral (l, cutpoints[i], cutpoints[i + 1 ]) for (i, l) in enumerate (lines)]
75+ weights = [exp_integral (l, cutpoints[i], cutpoints[i+ 1 ]) for (i, l) in enumerate (lines)]
6676 @assert Inf ∉ weights " Overflow in assigning weights"
6777 new (lines, cutpoints, weights, length (lines))
6878 end
@@ -84,19 +94,19 @@ function add_segment!(e::Envelop, l::Line)
8494 # Insert in second position, first one is the support bound
8595 insert! (e. cutpoints, pos + 1 , new_cut)
8696 elseif pos == e. size + 1
87- new_cut = intersection (l, e. lines[pos - 1 ])
97+ new_cut = intersection (l, e. lines[pos- 1 ])
8898 insert! (e. cutpoints, pos, new_cut)
8999 else
90- new_cut1 = intersection (l, e. lines[pos - 1 ])
100+ new_cut1 = intersection (l, e. lines[pos- 1 ])
91101 new_cut2 = intersection (l, e. lines[pos])
92102 splice! (e. cutpoints, pos, [new_cut1, new_cut2])
93- @assert issorted (e. cutpoints) " incompatible line: resulting intersection points aren't sorted"
103+ @assert issorted (e. cutpoints) " incompatible line: resulting intersection points aren't sorted"
94104 end
95105 # Insert the new line
96106 insert! (e. lines, pos, l)
97107 e. size += 1
98108 # Recompute weights (this could be done more efficiently in the future by updating the neccesary ones only)
99- e. weights = [exp_integral (line, e. cutpoints[i], e. cutpoints[i + 1 ]) for (i, line) in enumerate (e. lines)]
109+ e. weights = [exp_integral (line, e. cutpoints[i], e. cutpoints[i+ 1 ]) for (i, line) in enumerate (e. lines)]
100110end
101111
102112"""
@@ -129,7 +139,7 @@ function eval_envelop(e::Envelop, x::Float64)
129139 if pos == 1 || pos == length (e. cutpoints) + 1
130140 return 0.0
131141 else
132- a, b = e. lines[pos - 1 ]. slope, e. lines[pos - 1 ]. intercept
142+ a, b = e. lines[pos- 1 ]. slope, e. lines[pos- 1 ]. intercept
133143 return exp (a * x + b)
134144 end
135145end
@@ -147,27 +157,37 @@ struct Objective
147157 grad:: Function
148158 Objective (logf:: Function ) = begin
149159 # Automatic differentiation
150- grad (x) = ForwardDiff . derivative (logf, x)
160+ grad (x) = derivative (logf, x)
151161 new (logf, grad)
152162 end
153163 Objective (logf:: Function , grad:: Function ) = new (logf, grad)
154164end
155165
156166"""
167+ RejectionSampler(f::Function, support::Tuple{Float64, Float64}, init::Tuple{Float64, Float64})
157168 RejectionSampler(f::Function, support::Tuple{Float64, Float64}[ ,δ::Float64])
158- RejectionSampler(f::Function, support::Tuple{Float64, Float64}, init::Tuple{Float64, Float64})
159- An adaptive rejection sampler to obtain iid samples from a logconcave function `f`, supported in the
160- domain `support` = (support[1], support[2]). To create the object, two initial points `init = init[1], init[2]`
161- such that `loff'(init[1]) > 0` and `logf'(init[2]) < 0` are necessary. If they are not provided, the constructor
162- will perform a greedy search based on `δ`.
163-
164- The argument `support` must be of the form `(-Inf, Inf), (-Inf, a), (b, Inf), (a,b)`, and it represent the
169+ An adaptive rejection sampler to obtain iid samples from a logconcave function supported in
170+ `support = (support[1], support[2])`. f can either be the either probability density of the
171+ function to be sampled, or its logarithm. For the latter, use the keyword argument `log=true`.
172+ The functions can be unnormalized in the sense that the probability density can be specified up to a constant.
173+ The adaptive rejection samplings algorithm requires two initial points `init = init[1], init[2]`
174+ such that (d/dx)logp(init[1]) > 0 and (d/dx)logp(init[2]) < 0. These points can be provided directly
175+ (typically, any point left and right of the mode will do). It is also possibe to specify and search
176+ range and delta for a greedy search of the initial points.
177+ The `support` must be of the form `(-Inf, Inf), (-Inf, a), (b, Inf), (a,b)`, and it represent the
165178interval in which f has positive value, and zero elsewhere.
166179
180+ The alternative constructor uses a search_range, a value δ for the distance between points in the search,
181+ and min/max slope values in absolute terms.
182+
167183## Keyword arguments
168184- `max_segments::Int = 10` : max size of envelop, the rejection-rate is usually slow with a small number of segments
169185- `max_failed_factor::Float64 = 0.001`: level at which throw an error if one single sample has a rejection rate
170186 exceeding this value
187+ - `logdensity::Bool = false`: indicator fo whether `f` is the log of the probability density, up to a normalization constant.
188+ - `search_range::Tuple{Float64,Float64} = (-10.0, 10.0)`: range in which to search for initial points
189+ - `min_slope::Float64 = 1e-6`: minimum slope in absolute value of logf at the initial/found points
190+ - `max_slope::Float64 = Inf: maximum slope in absolute value of logf at the initial/found points
171191"""
172192struct RejectionSampler
173193 objective:: Objective
@@ -176,42 +196,56 @@ struct RejectionSampler
176196 max_failed_rate:: Float64
177197 # Constructor when initial points are provided
178198 RejectionSampler (
179- f:: Function ,
180- support:: Tuple{Float64, Float64} ,
181- init:: Tuple{Float64, Float64} ;
182- max_segments:: Int = 25 ,
183- max_failed_rate:: Float64 = 0.001
199+ f:: Function ,
200+ support:: Tuple{Float64,Float64} ,
201+ init:: Tuple{Float64,Float64} ;
202+ max_segments:: Int = 25 ,
203+ logdensity:: Bool = false ,
204+ max_failed_rate:: Float64 = 0.001 ,
184205 ) = begin
185206 @assert support[1 ] < support[2 ] " invalid support, not an interval"
186- logf (x) = log (f (x))
187- objective = Objective (logf)
207+ if logdensity
208+ objective = Objective (f)
209+ else
210+ objective = Objective (x -> log (f (x)))
211+ end
188212 x1, x2 = init
189213 @assert x1 < x2 " cutpoints must be ordered"
190214 a1, a2 = objective. grad (x1), objective. grad (x2)
191- @assert a1 >= 0 " logf must have positive slope at initial cutpoint 1"
192- @assert a2 <= 0 " logf must have negative slope at initial cutpoint 2"
215+ @assert 0.0 < a1 " logf must have positive slope at initial cutpoint 1"
216+ @assert a2 < 0. 0 " logf must have negative slope at initial cutpoint 2"
193217 b1, b2 = objective. logf (x1) - a1 * x1, objective. logf (x2) - a2 * x2
194218 line1, line2 = Line (a1, b1), Line (a2, b2)
195219 envelop = Envelop ([line1, line2], support)
196220 new (objective, envelop, max_segments, max_failed_rate)
197221 end
198-
199- # Constructor for greedy search of starting points
222+ " "
200223 RejectionSampler (
201- f:: Function ,
202- support:: Tuple{Float64, Float64} ,
203- δ:: Float64 = 0.5 ;
204- search_range:: Tuple{Float64, Float64} = (- 10.0 ,10.0 ),
205- kwargs...
224+ f:: Function ,
225+ support:: Tuple{Float64,Float64} ,
226+ δ:: Float64 = 0.5 ;
227+ search_range:: Tuple{Float64,Float64} = (- 10.0 , 10.0 ),
228+ min_slope:: Float64 = 1e-6 ,
229+ max_slope:: Float64 = 10.0 ,
230+ logdensity:: Bool = false ,
231+ kwargs...
206232 ) = begin
207- logf (x) = log (f (x))
208- grad (x) = ForwardDiff. derivative (logf, x)
233+ if logdensity
234+ logf = f
235+ else
236+ logf = (x -> log (f (x)))
237+ end
238+ grad (x) = derivative (logf, x)
209239 grid_lims = max (search_range[1 ], support[1 ]), min (search_range[2 ], support[2 ])
210240 grid = grid_lims[1 ]: δ:grid_lims [2 ]
211- i1, i2 = findfirst (grad .(grid) .> 0. ), findfirst (grad .(grid) .< 0. )
212- @assert (i1 != nothing ) && (i2 != nothing ) " couldn't find initial points, please provide them or change `search_range`"
241+ grads = grad .(grid)
242+ i1 = findfirst (min_slope .< grads .< max_slope)
243+ i2 = findlast (min_slope .< - grads .< max_slope)
244+ @assert (i1 != = nothing ) && (i2 != = nothing ) " couldn't find initial points, please provide them or change `search_range`"
245+ @assert i1 < i2 " function is not logconcave, first index with grad>0 must be smaller than first index with grad<0"
213246 x1, x2 = grid[i1], grid[i2]
214- RejectionSampler (f, support, (x1, x2); kwargs... )
247+ @info " initial points found at $(x1) , $(x2) with grads $(grads[i1]) , $(grads[i2]) "
248+ RejectionSampler (f, support, (x1, x2); logdensity= logdensity, kwargs... )
215249 end
216250end
217251
0 commit comments