Skip to content

Commit 1d267b5

Browse files
author
mauriciogtec
committed
checked tests with new framework
1 parent c00c251 commit 1d267b5

File tree

2 files changed

+76
-80
lines changed

2 files changed

+76
-80
lines changed

src/AdaptiveRejectionSampling.jl

Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using ForwardDiff # For automatic differentiation, no user nor approximate deriv
99
using StatsBase # To include the basic sample from array function
1010
# ------------------------------
1111
export Line, Objective, Envelop, RejectionSampler # Structures/classes
12-
export run_sampler!, eval_envelop # Methods
12+
export run_sampler!, eval_envelop, add_segment! # Methods
1313
# ------------------------------
1414

1515
"""
@@ -57,7 +57,7 @@ mutable struct Envelop
5757
Envelop(lines::Vector{Line}, support::Tuple{Float64, Float64}) = begin
5858
@assert issorted([l.slope for l in lines], rev = true) "line slopes must be decreasing"
5959
intersections = [intersection(lines[i], lines[i + 1]) for i in 1:(length(lines) - 1)]
60-
cutpoints = [support[1]; intersections; cutpoints[2]]
60+
cutpoints = [support[1]; intersections; support[2]]
6161
@assert issorted(cutpoints) "cutpoints must be ordered"
6262
@assert length(unique(cutpoints)) == length(cutpoints) "cutpoints can't have duplicates"
6363
weights = [exp_integral(l, cutpoints[i], cutpoints[i + 1]) for (i, l) in enumerate(lines)]
@@ -68,12 +68,12 @@ end
6868

6969

7070
"""
71-
add_line_segment!(e::Envelop, l::Line)
71+
add_segment!(e::Envelop, l::Line)
7272
Adds a new line segment to an envelop based on the value of its slope (slopes must be decreasing
7373
always in the envelop). The cutpoints are automatically determined by intersecting the line with
7474
the adjacent lines.
7575
"""
76-
function add_line_segment!(e::Envelop, l::Line)
76+
function add_segment!(e::Envelop, l::Line)
7777
# Find the position in sorted array with binary search
7878
pos = searchsortedfirst([-line.slope for line in e.lines], -l.slope)
7979
# Find the new cutpoints
@@ -87,29 +87,27 @@ function add_line_segment!(e::Envelop, l::Line)
8787
else
8888
new_cut1 = intersection(l, e.lines[pos - 1])
8989
new_cut2 = intersection(l, e.lines[pos])
90-
splice!(e.cutpoints, pos, [cut1, cut2])
90+
splice!(e.cutpoints, pos, [new_cut1, new_cut2])
9191
@assert issorted(e.cutpoints) "incompatible line: resulting intersection points aren't sorted"
9292
end
9393
# Insert the new line
9494
insert!(e.lines, pos, l)
95+
e.size += 1
9596
# Recompute weights (this could be done more efficiently in the future by updating the neccesary ones only)
9697
e.weights = [exp_integral(line, e.cutpoints[i], e.cutpoints[i + 1]) for (i, line) in enumerate(e.lines)]
9798
end
9899

99100
"""
100-
sample(p::Envelop, n::Int)
101-
Samples `n` elements iid from the density defined by the envelop `e` with it's exponential weights.
101+
sample_envelop(p::Envelop)
102+
Samples an element from the density defined by the envelop `e` with it's exponential weights.
102103
See [`Envelop`](@ref) for details.
103104
"""
104-
function sample(e::Envelop, n::Int)
105+
function sample_envelop(e::Envelop)
105106
# Randomly select lines based on envelop weights
106-
line_num = sample(1:e.size, weights(e.weights), n)
107-
a = [l.slope for l in e.lines]
108-
b = [l.intercept for l in e.lines]
109-
# Generate random uniforms
110-
u_list = rand(n)
107+
i = sample(1:e.size, weights(e.weights))
108+
a, b = e.lines[i].slope, e.lines[i].intercept
111109
# Use the inverse CDF method for sampling
112-
[log(exp(-b[i])*u*e.weights[i]*a[i] + exp(a[i]*e.cutpoints[i]))/a[i] for (i, u) in zip(line_num, u_list)]
110+
log(exp(-b) * rand() * e.weights[i] * a + exp(a * e.cutpoints[i])) / a
113111
end
114112

115113
"""
@@ -163,14 +161,17 @@ interval in which f has positive value, and zero elsewhere.
163161
mutable struct RejectionSampler
164162
objective::Objective
165163
envelop::Envelop
166-
164+
max_segments::Int
165+
max_failed_rate::Float64
166+
# Constructor when initial points are provided
167167
RejectionSampler(
168168
f::Function,
169169
support::Tuple{Float64, Float64},
170170
init::Tuple{Float64, Float64};
171171
max_segments::Int = 10,
172-
max_failed_factor::Float64 = 0.001
172+
max_failed_rate::Float64 = 0.001
173173
) = begin
174+
@assert support[1] < support[2] "invalid support, not an interval"
174175
logf(x) = log(f(x))
175176
objective = Objective(logf)
176177
x1, x2 = init
@@ -181,59 +182,48 @@ mutable struct RejectionSampler
181182
b1, b2 = objective.logf(x1) - a1 * x1, objective.logf(x2) - a2 * x2
182183
line1, line2 = Line(a1, b1), Line(a2, b2)
183184
envelop = Envelop([line1, line2], support)
184-
new(objective, envelop)
185+
new(objective, envelop, max_segments, max_failed_rate)
185186
end
186187

188+
# Constructor for greedy search of starting points
187189
RejectionSampler(
188190
f::Function,
189191
support::Tuple{Float64, Float64},
190192
δ::Float64 = 0.5;
191-
max_search_steps::Int = 100,
193+
search_range::Tuple{Float64, Float64} = (-10.0,10.0),
192194
kwargs...
193195
) = begin
194196
logf(x) = log(f(x))
195197
grad(x) = ForwardDiff.derivative(logf, x)
196-
x1, x2 = -δ, δ
197-
i, j = 0, 0
198-
while (grad(x1) <= 0 || grad(x2) >= 0) && i < max_attempts
199-
if grad(x1) <= 0
200-
x1 -= δ
201-
elsif grad(x2) >= 0
202-
x2 -= δ
203-
end
204-
i += 1
205-
end
206-
while (grad(x1) <= 0 || grad(x2) >= 0) && j < max_attempts
207-
if grad(x1) <= 0
208-
x1 += δ
209-
elsif grad(x2) >= 0
210-
x2 += δ
211-
end
212-
j += 1
213-
end
214-
@assert i != max_attempts && j != max_attempts "couldn't find initial points, please provide them or verify that f is logconcave"
215-
RejectionSampler(f, (x1, x2); kwargs...)
198+
grid = search_range[1]:δ:search_range[2]
199+
i1, i2 = findfirst(grad.(grid) .> 0.), findfirst(grad.(grid) .< 0.)
200+
println(i1, i2)
201+
@assert (i1 > 0) && (i2 == 0) "couldn't find initial points, please provide them or change `search_range`"
202+
x1, x2 = grid[i1], grid[i2]
203+
RejectionSampler(f, support, (x1, x2); kwargs...)
216204
end
217205
end
218206

219207
"""
220-
208+
run_sampler!(sampler::RejectionSampler, n::Int)
209+
It draws `n` iid samples of the objective function of `sampler`, and at each iteration it adapts the envelop
210+
of `sampler` by adding new segments to its envelop.
221211
"""
222-
function run_sampler!(sampler::RejectionSampler, n::Int)
212+
function run_sampler!(s::RejectionSampler, n::Int)
223213
i = 0
224-
failed, max_failed = 0, trunc(Int, n / max_failed_factor)
214+
failed, max_failed = 0, trunc(Int, n / s.max_failed_rate)
225215
out = zeros(n)
226216
while i < n
227-
candidate = get_samples(sampler.envelop, 1)[1]
228-
acceptance_ratio = exp(sampler.objective.logf(candidate)) / eval_envelop(sampler.envelop, candidate)
217+
candidate = sample_envelop(s.envelop)
218+
acceptance_ratio = exp(s.objective.logf(candidate)) / eval_envelop(s.envelop, candidate)
229219
if rand() < acceptance_ratio
230220
i += 1
231221
out[i] = candidate
232222
else
233-
if length(sampler.envelop.lines) <= max_segments
234-
a = sampler.objective.grad(candidate)
235-
b = sampler.objective.logf(candidate) - a * candidate
236-
add_line_segment!(sampler.envelop, Line(a, b))
223+
if length(s.envelop.lines) <= s.max_segments
224+
a = s.objective.grad(candidate)
225+
b = s.objective.logf(candidate) - a * candidate
226+
add_segment!(s.envelop, Line(a, b))
237227
end
238228
failed += 1
239229
@assert failed < max_failed "max_failed_factor reached"

test/runtests.jl

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,76 @@
11
import AdaptiveRejectionSampling
2-
arj = AdaptiveRejectionSampling
2+
ars = AdaptiveRejectionSampling
33

44
using Base.Test
55

66
@testset "Line" begin
7-
@test arj.Line(2.0, 3) isa arj.Line
8-
@test arj.intersection(arj.Line(-1.0, 1.0), arj.Line(1.0, -1.0)) == 1.0
9-
@test_throws AssertionError arj.intersection(arj.Line(-1.0, 1.0), arj.Line(-1.0, -1.0))
7+
@test ars.Line(2.0, 3) isa ars.Line
8+
@test ars.intersection(ars.Line(-1.0, 1.0), ars.Line(1.0, -1.0)) == 1.0
9+
@test_throws AssertionError ars.intersection(ars.Line(-1.0, 1.0), ars.Line(-1.0, -1.0))
1010
end
1111

1212
@testset "Envelop" begin
1313
@test begin
14-
l1 = arj.Line(1.0, 1.0)
15-
l2 = arj.Line(-3.0, 2.0)
16-
arj.Envelop([l1, l2])
17-
end isa arj.Envelop
14+
l1 = ars.Line(1.0, 1.0)
15+
l2 = ars.Line(-3.0, 2.0)
16+
support = (-Inf, Inf)
17+
ars.Envelop([l1, l2], support)
18+
end isa ars.Envelop
1819
@test_throws AssertionError begin
19-
l1 = arj.Line(1.0, 1.0)
20-
l2 = arj.Line(-3.0, 2.0)
21-
arj.Envelop([l2, l1])
20+
l1 = ars.Line(1.0, 1.0)
21+
l2 = ars.Line(-3.0, 2.0)
22+
ars.Envelop([l2, l1], (-Inf, Inf))
23+
end
24+
@test_throws AssertionError begin
25+
l1 = ars.Line(1.0, 1.0)
26+
l2 = ars.Line(-3.0, 2.0)
27+
ars.Envelop([l2, l1], (1.0, -1.0))
2228
end
2329
@test begin
24-
l1 = arj.Line(1.0, 1.0)
25-
l2 = arj.Line(-3.0, 2.0)
26-
l3 = arj.Line(-5.0, 5.0)
27-
p = arj.Envelop([l1, l2, l3])
28-
p.cutpoints == [0.25, 1.5]
30+
l1 = ars.Line(1.0, 1.0)
31+
l2 = ars.Line(-3.0, 2.0)
32+
l3 = ars.Line(-5.0, 5.0)
33+
e = ars.Envelop([l1, l2, l3], (-Inf, Inf))
34+
e.cutpoints == [-Inf, 0.25, 1.5, Inf]
2935
end
3036
@test begin
31-
l1 = arj.Line(1.0, 1.0)
32-
newline = arj.Line(0.0, 0.0)
33-
l3 = arj.Line(-1.0, 1.0)
34-
p = arj.Envelop([l1, l3])
35-
arj.add_line_segment!(p, newline)
36-
p.lines[2] == newline
37+
l1 = ars.Line(1.0, 1.0)
38+
newline = ars.Line(0.0, 0.0)
39+
l3 = ars.Line(-1.0, 1.0)
40+
e = ars.Envelop([l1, l3], (-Inf, Inf))
41+
ars.add_segment!(e, newline)
42+
e.lines[2] == newline
3743
end
3844
end
3945

4046
@testset "Objective" begin
41-
@test arj.Objective(x -> exp(-0.5 * x^2) / sqrt(2pi)) isa arj.Objective
47+
@test ars.Objective(x -> exp(-0.5 * x^2) / sqrt(2pi)) isa ars.Objective
4248
@test begin
43-
objective = arj.Objective(x -> exp(-0.5 * x^2) / sqrt(2pi))
49+
objective = ars.Objective(x -> exp(-0.5 * x^2) / sqrt(2pi))
4450
objective.logf(0.0) 0.3989422804014327
4551
end
4652
@test begin
47-
objective = arj.Objective(x -> exp(-0.5 * x^2) / sqrt(2pi))
53+
objective = ars.Objective(x -> exp(-0.5 * x^2) / sqrt(2pi))
4854
objective.grad(0.0) 0.0
4955
end
5056
@test begin
51-
objective = arj.Objective(x -> exp(-0.5 * x^2) / sqrt(2pi))
57+
objective = ars.Objective(x -> exp(-0.5 * x^2) / sqrt(2pi))
5258
objective.logf(Inf) == 0.0 && objective.logf(Inf) == 0.0
5359
end
5460
end
5561

5662
@testset "RejectionSampler" begin
57-
@test arj.RejectionSampler(x -> exp(-0.5 * x^2) / sqrt(2pi), -1.0, 1.0) isa arj.RejectionSampler
63+
@test ars.RejectionSampler(x -> exp(-0.5 * x^2) / sqrt(2pi), (-Inf, Inf), (-1.0, 1.0)) isa ars.RejectionSampler
5864
@test begin
59-
sampler = arj.RejectionSampler(x -> exp(-0.5 * x^2) / sqrt(2pi), -1.0, 1.0)
60-
arj.run_sampler!(sampler, 5) isa Vector{T} where T <: AbstractFloat
65+
sampler = ars.RejectionSampler(x -> exp(-0.5 * x^2) / sqrt(2pi), (-Inf, Inf), (-1.0, 1.0))
66+
ars.run_sampler!(sampler, 5) isa Vector{T} where T <: AbstractFloat
6167
end
6268
@test begin
63-
sampler = arj.RejectionSampler(x -> exp(-0.5 * x^2) / sqrt(2pi))
64-
arj.run_sampler!(sampler, 5) isa Vector{T} where T <: AbstractFloat
69+
sampler = ars.RejectionSampler(x -> exp(-0.5 * x^2) / sqrt(2pi), (-Inf, Inf))
70+
ars.run_sampler!(sampler, 5) isa Vector{T} where T <: AbstractFloat
6571
end
6672
@test_throws AssertionError begin
67-
sampler = arj.RejectionSampler(x -> exp(-0.5 * x) / sqrt(2pi))
68-
arj.run_sampler!(sampler, 5) isa Vector{T} where T <: AbstractFloat
73+
sampler = ars.RejectionSampler(x -> exp(-0.5 * x) / sqrt(2pi), (-Inf, Inf))
74+
ars.run_sampler!(sampler, 5) isa Vector{T} where T <: AbstractFloat
6975
end
7076
end

0 commit comments

Comments
 (0)