Skip to content

Commit ffc8576

Browse files
authored
Merge pull request #15 from JuliaReinforcementLearning/metasampler
Metasampler, MultiBatchSampler, SamplerControler
2 parents c2db98d + efffb6a commit ffc8576

File tree

9 files changed

+230
-88
lines changed

9 files changed

+230
-88
lines changed

src/Trajectories.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module Trajectories
22

33
include("samplers.jl")
4-
include("controlers.jl")
4+
include("controllers.jl")
55
include("traces.jl")
66
include("episodes.jl")
77
include("trajectory.jl")

src/controlers.jl

Lines changed: 0 additions & 61 deletions
This file was deleted.

src/controllers.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
export InsertSampleRatioController, InsertSampleController, AsyncInsertSampleRatioController
2+
3+
mutable struct InsertSampleRatioController
4+
ratio::Float64
5+
threshold::Int
6+
n_inserted::Int
7+
n_sampled::Int
8+
end
9+
10+
"""
11+
InsertSampleRatioController(ratio, threshold)
12+
13+
Used in [`Trajectory`](@ref). The `threshold` means the minimal number of
14+
insertings before sampling. The `ratio` balances the number of insertings and
15+
the number of samplings.
16+
"""
17+
InsertSampleRatioController(ratio, threshold) = InsertSampleRatioController(ratio, threshold, 0, 0)
18+
19+
function on_insert!(c::InsertSampleRatioController, n::Int)
20+
if n > 0
21+
c.n_inserted += n
22+
end
23+
end
24+
25+
function on_sample!(c::InsertSampleRatioController)
26+
if c.n_inserted >= c.threshold
27+
if c.n_sampled <= (c.n_inserted - c.threshold) * c.ratio
28+
c.n_sampled += 1
29+
true
30+
end
31+
end
32+
end
33+
34+
"""
35+
InsertSampleController(n, threshold)
36+
37+
Used in [`Trajectory`](@ref). The `threshold` means the minimal number of
38+
insertings before sampling. The `n` is the number of samples until stopping.
39+
"""
40+
mutable struct InsertSampleController
41+
n::Int
42+
threshold::Int
43+
n_inserted::Int
44+
n_sampled::Int
45+
end
46+
47+
InsertSampleController(n, threshold) = InsertSampleController(n, threshold, 0, 0)
48+
49+
function on_insert!(c::InsertSampleController, n::Int)
50+
if n > 0
51+
c.n_inserted += n
52+
end
53+
end
54+
55+
function on_sample!(c::InsertSampleController)
56+
if c.n_inserted >= c.threshold
57+
if c.n_sampled < c.n
58+
c.n_sampled += 1
59+
true
60+
end
61+
end
62+
end
63+
64+
#####
65+
66+
mutable struct AsyncInsertSampleRatioController
67+
ratio::Float64
68+
threshold::Int
69+
n_inserted::Int
70+
n_sampled::Int
71+
ch_in::Channel
72+
ch_out::Channel
73+
end
74+
75+
function AsyncInsertSampleRatioController(
76+
ratio,
77+
threshold,
78+
; ch_in_sz=1,
79+
ch_out_sz=1,
80+
n_inserted=0,
81+
n_sampled=0
82+
)
83+
AsyncInsertSampleRatioController(
84+
ratio,
85+
threshold,
86+
n_inserted,
87+
n_sampled,
88+
Channel(ch_in_sz),
89+
Channel(ch_out_sz)
90+
)
91+
end

src/rendering.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ function Base.convert(r::Type{Term.AbstractRenderable}, t::Trajectory; width=88)
110110
Panel(
111111
convert(r, t.container; width=width - 8) /
112112
Panel(convert(Term.Tree, t.sampler); title="sampler", style="yellow3", fit=true, width=width - 8) /
113-
Panel(convert(Term.Tree, t.controler); title="controler", style="yellow3", fit=true, width=width - 8);
113+
Panel(convert(Term.Tree, t.controller); title="controller", style="yellow3", fit=true, width=width - 8);
114114
title="Trajectory",
115115
style="yellow3",
116116
width=width,

src/samplers.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
export BatchSampler
1+
export BatchSampler, MetaSampler, MultiBatchSampler
22

33
using Random
44

5-
struct BatchSampler
5+
abstract type AbstractSampler end
6+
7+
struct BatchSampler <: AbstractSampler
68
batch_size::Int
79
rng::Random.AbstractRNG
810
transformer::Any
@@ -16,3 +18,40 @@ Uniformly sample a batch of examples for each trace.
1618
See also [`sample`](@ref).
1719
"""
1820
BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity) = BatchSampler(batch_size, rng, identity)
21+
22+
"""
23+
MetaSampler(::NamedTuple)
24+
25+
Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a batch from each sampler.
26+
Used internally for algorithms that sample multiple times per epoch.
27+
28+
# Example
29+
30+
MetaSampler(policy = BatchSampler(10), critic = BatchSampler(100))
31+
"""
32+
struct MetaSampler{names, T} <: AbstractSampler
33+
samplers::NamedTuple{names, T}
34+
end
35+
36+
MetaSampler(; kw...) = MetaSampler(NamedTuple(kw))
37+
38+
function sample(s::MetaSampler, t)
39+
(;[(k, sample(v, t)) for (k,v) in pairs(s.samplers)]...)
40+
end
41+
42+
43+
"""
44+
MultiBatchSampler(sampler, n)
45+
46+
Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination with MetaSampler to allow different sampling rates between samplers.
47+
48+
# Example
49+
50+
MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3), critic = MultiBatchSampler(BatchSampler(100), 5))
51+
"""
52+
struct MultiBatchSampler{S <: AbstractSampler} <: AbstractSampler
53+
sampler::S
54+
n::Int
55+
end
56+
57+
sample(m::MultiBatchSampler, t) = [sample(m.sampler, t) for _ in 1:m.n]

src/trajectory.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ using Base.Threads
44

55

66
"""
7-
Trajectory(container, sampler, controler)
7+
Trajectory(container, sampler, controller)
88
99
The `container` is used to store experiences. Common ones are [`Traces`](@ref)
1010
or [`Episodes`](@ref). The `sampler` is used to sample experience batches from
11-
the `container`. The `controler` controls whether it is time to sample a batch
11+
the `container`. The `controller` controls whether it is time to sample a batch
1212
or not.
1313
1414
Supported methoes are:
@@ -21,35 +21,35 @@ Supported methoes are:
2121
Base.@kwdef struct Trajectory{C,S,T}
2222
container::C
2323
sampler::S
24-
controler::T
24+
controller::T
2525

2626
Trajectory(c::C, s::S, t::T) where {C,S,T} = new{C,S,T}(c, s, t)
2727

28-
function Trajectory(container::C, sampler::S, controler::T) where {C,S,T<:AsyncInsertSampleRatioControler}
28+
function Trajectory(container::C, sampler::S, controller::T) where {C,S,T<:AsyncInsertSampleRatioController}
2929
t = Threads.@spawn while true
30-
for msg in controler.ch_in
30+
for msg in controller.ch_in
3131
if msg.f === Base.push! || msg.f === Base.append!
3232
n_pre = length(container)
3333
msg.f(container, msg.args...; msg.kw...)
3434
n_post = length(container)
35-
controler.n_inserted += n_post - n_pre
35+
controller.n_inserted += n_post - n_pre
3636
else
3737
msg.f(container, msg.args...; msg.kw...)
3838
end
3939

40-
if controler.n_inserted >= controler.threshold
41-
if controler.n_sampled <= (controler.n_inserted - controler.threshold) * controler.ratio
40+
if controller.n_inserted >= controller.threshold
41+
if controller.n_sampled <= (controller.n_inserted - controller.threshold) * controller.ratio
4242
batch = sample(sampler, container)
43-
put!(controler.ch_out, batch)
44-
controler.n_sampled += 1
43+
put!(controller.ch_out, batch)
44+
controller.n_sampled += 1
4545
end
4646
end
4747
end
4848
end
4949

50-
bind(controler.ch_in, t)
51-
bind(controler.ch_out, t)
52-
new{C,S,T}(container, sampler, controler)
50+
bind(controller.ch_in, t)
51+
bind(controller.ch_out, t)
52+
new{C,S,T}(container, sampler, controller)
5353
end
5454
end
5555

@@ -60,7 +60,7 @@ function Base.push!(t::Trajectory, x)
6060
n_pre = length(t.container)
6161
push!(t.container, x)
6262
n_post = length(t.container)
63-
on_insert!(t.controler, n_post - n_pre)
63+
on_insert!(t.controller, n_post - n_pre)
6464
end
6565

6666
struct CallMsg
@@ -69,21 +69,21 @@ struct CallMsg
6969
kw::Any
7070
end
7171

72-
Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...; kw...) = put!(t.controler.ch_in, CallMsg(Base.push!, args, kw))
73-
Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...; kw...) = put!(t.controler.ch_in, CallMsg(Base.append!, args, kw))
72+
Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...; kw...) = put!(t.controller.ch_in, CallMsg(Base.push!, args, kw))
73+
Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...; kw...) = put!(t.controller.ch_in, CallMsg(Base.append!, args, kw))
7474

7575
Base.append!(t::Trajectory; kw...) = append!(t, values(kw))
7676

7777
function Base.append!(t::Trajectory, x)
7878
n_pre = length(t.container)
7979
append!(t.container, x)
8080
n_post = length(t.container)
81-
on_insert!(t.controler, n_post - n_pre)
81+
on_insert!(t.controller, n_post - n_pre)
8282
end
8383

8484
function Base.take!(t::Trajectory)
85-
res = on_sample!(t.controler)
86-
if isnothing(res)
85+
res = on_sample!(t.controller)
86+
if isnothing(res) && !isnothing(t.controller)
8787
nothing
8888
else
8989
sample(t.sampler, t.container)
@@ -101,5 +101,5 @@ end
101101

102102
Base.iterate(t::Trajectory, state) = iterate(t)
103103

104-
Base.iterate(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...) = iterate(t.controler.ch_out, args...)
105-
Base.take!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}) = take!(t.controler.ch_out)
104+
Base.iterate(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...) = iterate(t.controller.ch_out, args...)
105+
Base.take!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}) = take!(t.controller.ch_out)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ using Test
44
@testset "Trajectories.jl" begin
55
include("traces.jl")
66
include("trajectories.jl")
7+
include("samplers.jl")
78
end

0 commit comments

Comments
 (0)