@@ -9,7 +9,7 @@ using ForwardDiff # For automatic differentiation, no user nor approximate deriv
99using StatsBase # To include the basic sample from array function
1010# ------------------------------
1111export Line, Objective, Envelop, RejectionSampler # Structures/classes
12- export run_sampler!, eval_envelop # Methods
12+ export run_sampler!, eval_envelop, add_segment! # Methods
1313# ------------------------------
1414
1515"""
@@ -57,7 +57,7 @@ mutable struct Envelop
5757 Envelop (lines:: Vector{Line} , support:: Tuple{Float64, Float64} ) = begin
5858 @assert issorted ([l. slope for l in lines], rev = true ) " line slopes must be decreasing"
5959 intersections = [intersection (lines[i], lines[i + 1 ]) for i in 1 : (length (lines) - 1 )]
60- cutpoints = [support[1 ]; intersections; cutpoints [2 ]]
60+ cutpoints = [support[1 ]; intersections; support [2 ]]
6161 @assert issorted (cutpoints) " cutpoints must be ordered"
6262 @assert length (unique (cutpoints)) == length (cutpoints) " cutpoints can't have duplicates"
6363 weights = [exp_integral (l, cutpoints[i], cutpoints[i + 1 ]) for (i, l) in enumerate (lines)]
6868
6969
7070"""
71- add_line_segment !(e::Envelop, l::Line)
71+ add_segment !(e::Envelop, l::Line)
7272Adds a new line segment to an envelop based on the value of its slope (slopes must be decreasing
7373always in the envelop). The cutpoints are automatically determined by intersecting the line with
7474the adjacent lines.
7575"""
76- function add_line_segment ! (e:: Envelop , l:: Line )
76+ function add_segment ! (e:: Envelop , l:: Line )
7777 # Find the position in sorted array with binary search
7878 pos = searchsortedfirst ([- line. slope for line in e. lines], - l. slope)
7979 # Find the new cutpoints
@@ -87,29 +87,27 @@ function add_line_segment!(e::Envelop, l::Line)
8787 else
8888 new_cut1 = intersection (l, e. lines[pos - 1 ])
8989 new_cut2 = intersection (l, e. lines[pos])
90- splice! (e. cutpoints, pos, [cut1, cut2 ])
90+ splice! (e. cutpoints, pos, [new_cut1, new_cut2 ])
9191 @assert issorted (e. cutpoints) " incompatible line: resulting intersection points aren't sorted"
9292 end
9393 # Insert the new line
9494 insert! (e. lines, pos, l)
95+ e. size += 1
9596 # Recompute weights (this could be done more efficiently in the future by updating the neccesary ones only)
9697 e. weights = [exp_integral (line, e. cutpoints[i], e. cutpoints[i + 1 ]) for (i, line) in enumerate (e. lines)]
9798end
9899
99100"""
100- sample (p::Envelop, n::Int )
101- Samples `n` elements iid from the density defined by the envelop `e` with it's exponential weights.
101+ sample_envelop (p::Envelop)
102+ Samples an element from the density defined by the envelop `e` with it's exponential weights.
102103See [`Envelop`](@ref) for details.
103104"""
104- function sample (e:: Envelop , n :: Int )
105+ function sample_envelop (e:: Envelop )
105106 # Randomly select lines based on envelop weights
106- line_num = sample (1 : e. size, weights (e. weights), n)
107- a = [l. slope for l in e. lines]
108- b = [l. intercept for l in e. lines]
109- # Generate random uniforms
110- u_list = rand (n)
107+ i = sample (1 : e. size, weights (e. weights))
108+ a, b = e. lines[i]. slope, e. lines[i]. intercept
111109 # Use the inverse CDF method for sampling
112- [ log (exp (- b[i]) * u * e. weights[i]* a[i] + exp (a[i] * e. cutpoints[i]))/ a[i] for (i, u) in zip (line_num, u_list)]
110+ log (exp (- b) * rand () * e. weights[i] * a + exp (a * e. cutpoints[i])) / a
113111end
114112
115113"""
@@ -163,14 +161,17 @@ interval in which f has positive value, and zero elsewhere.
163161mutable struct RejectionSampler
164162 objective:: Objective
165163 envelop:: Envelop
166-
164+ max_segments:: Int
165+ max_failed_rate:: Float64
166+ # Constructor when initial points are provided
167167 RejectionSampler (
168168 f:: Function ,
169169 support:: Tuple{Float64, Float64} ,
170170 init:: Tuple{Float64, Float64} ;
171171 max_segments:: Int = 10 ,
172- max_failed_factor :: Float64 = 0.001
172+ max_failed_rate :: Float64 = 0.001
173173 ) = begin
174+ @assert support[1 ] < support[2 ] " invalid support, not an interval"
174175 logf (x) = log (f (x))
175176 objective = Objective (logf)
176177 x1, x2 = init
@@ -181,59 +182,48 @@ mutable struct RejectionSampler
181182 b1, b2 = objective. logf (x1) - a1 * x1, objective. logf (x2) - a2 * x2
182183 line1, line2 = Line (a1, b1), Line (a2, b2)
183184 envelop = Envelop ([line1, line2], support)
184- new (objective, envelop)
185+ new (objective, envelop, max_segments, max_failed_rate )
185186 end
186187
188+ # Constructor for greedy search of starting points
187189 RejectionSampler (
188190 f:: Function ,
189191 support:: Tuple{Float64, Float64} ,
190192 δ:: Float64 = 0.5 ;
191- max_search_steps :: Int = 100 ,
193+ search_range :: Tuple{Float64, Float64} = ( - 10.0 , 10.0 ) ,
192194 kwargs...
193195 ) = begin
194196 logf (x) = log (f (x))
195197 grad (x) = ForwardDiff. derivative (logf, x)
196- x1, x2 = - δ, δ
197- i, j = 0 , 0
198- while (grad (x1) <= 0 || grad (x2) >= 0 ) && i < max_attempts
199- if grad (x1) <= 0
200- x1 -= δ
201- elsif grad (x2) >= 0
202- x2 -= δ
203- end
204- i += 1
205- end
206- while (grad (x1) <= 0 || grad (x2) >= 0 ) && j < max_attempts
207- if grad (x1) <= 0
208- x1 += δ
209- elsif grad (x2) >= 0
210- x2 += δ
211- end
212- j += 1
213- end
214- @assert i != max_attempts && j != max_attempts " couldn't find initial points, please provide them or verify that f is logconcave"
215- RejectionSampler (f, (x1, x2); kwargs... )
198+ grid = search_range[1 ]: δ:search_range [2 ]
199+ i1, i2 = findfirst (grad .(grid) .> 0. ), findfirst (grad .(grid) .< 0. )
200+ println (i1, i2)
201+ @assert (i1 > 0 ) && (i2 == 0 ) " couldn't find initial points, please provide them or change `search_range`"
202+ x1, x2 = grid[i1], grid[i2]
203+ RejectionSampler (f, support, (x1, x2); kwargs... )
216204 end
217205end
218206
219207"""
220-
208+ run_sampler!(sampler::RejectionSampler, n::Int)
209+ It draws `n` iid samples of the objective function of `sampler`, and at each iteration it adapts the envelop
210+ of `sampler` by adding new segments to its envelop.
221211"""
222- function run_sampler! (sampler :: RejectionSampler , n:: Int )
212+ function run_sampler! (s :: RejectionSampler , n:: Int )
223213 i = 0
224- failed, max_failed = 0 , trunc (Int, n / max_failed_factor )
214+ failed, max_failed = 0 , trunc (Int, n / s . max_failed_rate )
225215 out = zeros (n)
226216 while i < n
227- candidate = get_samples (sampler . envelop, 1 )[ 1 ]
228- acceptance_ratio = exp (sampler . objective. logf (candidate)) / eval_envelop (sampler . envelop, candidate)
217+ candidate = sample_envelop (s . envelop)
218+ acceptance_ratio = exp (s . objective. logf (candidate)) / eval_envelop (s . envelop, candidate)
229219 if rand () < acceptance_ratio
230220 i += 1
231221 out[i] = candidate
232222 else
233- if length (sampler . envelop. lines) <= max_segments
234- a = sampler . objective. grad (candidate)
235- b = sampler . objective. logf (candidate) - a * candidate
236- add_line_segment! (sampler . envelop, Line (a, b))
223+ if length (s . envelop. lines) <= s . max_segments
224+ a = s . objective. grad (candidate)
225+ b = s . objective. logf (candidate) - a * candidate
226+ add_segment! (s . envelop, Line (a, b))
237227 end
238228 failed += 1
239229 @assert failed < max_failed " max_failed_factor reached"
0 commit comments