Skip to content

Commit d609559

Browse files
committed
merge latest master
2 parents 162d4b9 + ffc8576 commit d609559

File tree

9 files changed

+363
-89
lines changed

9 files changed

+363
-89
lines changed

src/Trajectories.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ include("patch.jl")
44

55
include("traces.jl")
66
include("samplers.jl")
7-
include("controlers.jl")
7+
include("controllers.jl")
88
include("trajectory.jl")
99
include("common/common.jl")
1010

src/controlers.jl

Lines changed: 0 additions & 61 deletions
This file was deleted.

src/controllers.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
export InsertSampleRatioController, InsertSampleController, AsyncInsertSampleRatioController
2+
3+
mutable struct InsertSampleRatioController
4+
ratio::Float64
5+
threshold::Int
6+
n_inserted::Int
7+
n_sampled::Int
8+
end
9+
10+
"""
11+
InsertSampleRatioController(ratio, threshold)
12+
13+
Used in [`Trajectory`](@ref). The `threshold` means the minimal number of
14+
insertings before sampling. The `ratio` balances the number of insertings and
15+
the number of samplings.
16+
"""
17+
InsertSampleRatioController(ratio, threshold) = InsertSampleRatioController(ratio, threshold, 0, 0)
18+
19+
function on_insert!(c::InsertSampleRatioController, n::Int)
20+
if n > 0
21+
c.n_inserted += n
22+
end
23+
end
24+
25+
function on_sample!(c::InsertSampleRatioController)
26+
if c.n_inserted >= c.threshold
27+
if c.n_sampled <= (c.n_inserted - c.threshold) * c.ratio
28+
c.n_sampled += 1
29+
true
30+
end
31+
end
32+
end
33+
34+
"""
35+
InsertSampleController(n, threshold)
36+
37+
Used in [`Trajectory`](@ref). The `threshold` means the minimal number of
38+
insertings before sampling. The `n` is the number of samples until stopping.
39+
"""
40+
mutable struct InsertSampleController
41+
n::Int
42+
threshold::Int
43+
n_inserted::Int
44+
n_sampled::Int
45+
end
46+
47+
InsertSampleController(n, threshold) = InsertSampleController(n, threshold, 0, 0)
48+
49+
function on_insert!(c::InsertSampleController, n::Int)
50+
if n > 0
51+
c.n_inserted += n
52+
end
53+
end
54+
55+
function on_sample!(c::InsertSampleController)
56+
if c.n_inserted >= c.threshold
57+
if c.n_sampled < c.n
58+
c.n_sampled += 1
59+
true
60+
end
61+
end
62+
end
63+
64+
#####
65+
66+
mutable struct AsyncInsertSampleRatioController
67+
ratio::Float64
68+
threshold::Int
69+
n_inserted::Int
70+
n_sampled::Int
71+
ch_in::Channel
72+
ch_out::Channel
73+
end
74+
75+
function AsyncInsertSampleRatioController(
76+
ratio,
77+
threshold,
78+
; ch_in_sz=1,
79+
ch_out_sz=1,
80+
n_inserted=0,
81+
n_sampled=0
82+
)
83+
AsyncInsertSampleRatioController(
84+
ratio,
85+
threshold,
86+
n_inserted,
87+
n_sampled,
88+
Channel(ch_in_sz),
89+
Channel(ch_out_sz)
90+
)
91+
end

src/rendering.jl

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
using Term
2+
3+
const TRACE_COLORS = ("bright_green", "hot_pink", "bright_blue", "light_coral", "bright_cyan", "sandy_brown", "violet")
4+
5+
Base.show(io::IO, ::MIME"text/plain", t::Union{Trace,Traces,Episode,Episodes,Trajectory}) = tprint(io, convert(Term.AbstractRenderable, t; width=displaysize(io)[2]) |> string)
6+
7+
inner_convert(::Type{Term.AbstractRenderable}, s::String; style="gray1", width=88) = Panel(s, width=width, style=style, justify=:center)
8+
inner_convert(t::Type{Term.AbstractRenderable}, x::Union{Symbol,Number}; kw...) = inner_convert(t, string(x); kw...)
9+
10+
function inner_convert(::Type{Term.AbstractRenderable}, x::AbstractArray; style="gray1", width=88)
11+
t = string(nameof(typeof(x)))
12+
s = replace(string(size(x)), " " => "")
13+
Panel(t * "\n" * s, style=style, justify=:center, width=width)
14+
end
15+
16+
function inner_convert(::Type{Term.AbstractRenderable}, x; style="gray1", width=88)
17+
s = string(nameof(typeof(x)))
18+
Panel(s, style=style, justify=:center, width=width)
19+
end
20+
21+
Base.convert(T::Type{Term.AbstractRenderable}, t::Trace{<:AbstractArray}; kw...) = convert(T, Trace(collect(eachslice(t.x, dims=ndims(t.x)))); kw..., type=typeof(t), subtitle="size: $(size(t.x))")
22+
23+
function Base.convert(
24+
::Type{Term.AbstractRenderable},
25+
t::Trace{<:AbstractVector};
26+
width=88,
27+
n_head=2,
28+
n_tail=1,
29+
name="Trace",
30+
style=TRACE_COLORS[mod1(hash(name), length(TRACE_COLORS))],
31+
type=typeof(t),
32+
subtitle="size: $(size(t.x))"
33+
)
34+
title = "$name: [italic]$type[/italic] "
35+
min_width = min(width, length(title) - 4)
36+
37+
n = length(t.x)
38+
if n == 0
39+
content = ""
40+
elseif 1 <= n <= n_head + n_tail
41+
content = mapreduce(x -> inner_convert(Term.AbstractRenderable, x, style=style, width=min_width - 6), /, t.x)
42+
else
43+
content = mapreduce(x -> inner_convert(Term.AbstractRenderable, x, style=style, width=min_width - 6), /, t.x[1:n_head]) /
44+
TextBox("...", justify=:center, width=min_width - 6) /
45+
mapreduce(x -> inner_convert(Term.AbstractRenderable, x, style=style, width=min_width - 6), /, t.x[end-n_tail+1:end])
46+
end
47+
Panel(content, width=min_width, title=title, subtitle=subtitle, subtitle_justify=:right, style=style, subtitle_style="yellow")
48+
end
49+
50+
function Base.convert(::Type{Term.AbstractRenderable}, t::Traces; width=88)
51+
max_len = mapreduce(length, max, t.traces)
52+
min_len = mapreduce(length, min, t.traces)
53+
if max_len - min_len == 1
54+
n_tails = [length(x) == max_len ? 2 : 1 for x in t.traces]
55+
else
56+
n_tails = [1 for x in t.traces]
57+
end
58+
N = length(t.traces)
59+
max_inner_width = ceil(Int, (width - 6 * 2) / N)
60+
Panel(
61+
mapreduce(((i, x),) -> convert(Term.AbstractRenderable, t[x]; width=max_inner_width, name=x, n_tail=n_tails[i], style=TRACE_COLORS[mod1(i, length(TRACE_COLORS))]), *, enumerate(keys(t))),
62+
title="Traces",
63+
style="yellow3",
64+
subtitle="$N traces in total",
65+
subtitle_justify=:right,
66+
width=width,
67+
fit=true
68+
)
69+
end
70+
71+
function Base.convert(::Type{Term.AbstractRenderable}, e::Episode; width=88)
72+
Panel(
73+
convert(Term.AbstractRenderable, e.traces; width=width - 6),
74+
title="Episode",
75+
style="green_yellow",
76+
subtitle=e[] ? "Episode END" : "Episode growing...",
77+
subtitle_justify=:right,
78+
width=width,
79+
fit=true
80+
)
81+
end
82+
83+
function Base.convert(::Type{Term.AbstractRenderable}, e::Episodes; width=88)
84+
n = length(e)
85+
if n == 0
86+
content = ""
87+
elseif n == 1
88+
content = convert(Term.AbstractRenderable, e[1], width=width - 6)
89+
elseif n == 2
90+
content = convert(Term.AbstractRenderable, e[1], width=width - 6) /
91+
convert(Term.AbstractRenderable, e[end], width=width - 6)
92+
else
93+
content = convert(Term.AbstractRenderable, e[1], width=width - 6) /
94+
TextBox("...", justify=:center, width=width - 6) /
95+
convert(Term.AbstractRenderable, e[end], width=width - 6)
96+
end
97+
98+
Panel(
99+
content,
100+
title="Episodes",
101+
subtitle="$n episodes in total",
102+
subtitle_justify=:right,
103+
width=width,
104+
fit=true,
105+
style="wheat1"
106+
)
107+
end
108+
109+
function Base.convert(r::Type{Term.AbstractRenderable}, t::Trajectory; width=88)
110+
Panel(
111+
convert(r, t.container; width=width - 8) /
112+
Panel(convert(Term.Tree, t.sampler); title="sampler", style="yellow3", fit=true, width=width - 8) /
113+
Panel(convert(Term.Tree, t.controller); title="controller", style="yellow3", fit=true, width=width - 8);
114+
title="Trajectory",
115+
style="yellow3",
116+
width=width,
117+
fit=true
118+
)
119+
end
120+
121+
# general converter
122+
123+
Base.convert(::Type{Term.Tree}, x) = Tree(to_tree_body(x); title=to_tree_title(x))
124+
Base.convert(::Type{Term.Tree}, x::Tree) = x
125+
126+
function to_tree_body(x)
127+
pts = propertynames(x)
128+
if length(pts) > 0
129+
Dict("$p => $(summary(getproperty(x, p)))" => to_tree_body(getproperty(x, p)) for p in pts)
130+
else
131+
x
132+
end
133+
end
134+
135+
to_tree_title(x) = "$(summary(x))"

src/samplers.jl

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
export BatchSampler
1+
export BatchSampler, MetaSampler, MultiBatchSampler
22

33
using MLUtils: batch
44

55
using Random
66

7-
struct BatchSampler
7+
abstract type AbstractSampler end
8+
9+
struct BatchSampler <: AbstractSampler
810
batch_size::Int
911
rng::Random.AbstractRNG
1012
transformer::Any
@@ -23,3 +25,38 @@ function sample(s::BatchSampler, t::AbstractTraces)
2325
inds = rand(s.rng, 1:length(t), s.batch_size)
2426
map(s.transformer, t[inds])
2527
end
28+
29+
"""
30+
MetaSampler(::NamedTuple)
31+
32+
Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a batch from each sampler.
33+
Used internally for algorithms that sample multiple times per epoch.
34+
35+
# Example
36+
37+
MetaSampler(policy = BatchSampler(10), critic = BatchSampler(100))
38+
"""
39+
struct MetaSampler{names,T} <: AbstractSampler
40+
samplers::NamedTuple{names,T}
41+
end
42+
43+
MetaSampler(; kw...) = MetaSampler(NamedTuple(kw))
44+
45+
sample(s::MetaSampler, t) = map(x -> sample(x, t), s.samplers)
46+
47+
48+
"""
49+
MultiBatchSampler(sampler, n)
50+
51+
Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination with MetaSampler to allow different sampling rates between samplers.
52+
53+
# Example
54+
55+
MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3), critic = MultiBatchSampler(BatchSampler(100), 5))
56+
"""
57+
struct MultiBatchSampler{S<:AbstractSampler} <: AbstractSampler
58+
sampler::S
59+
n::Int
60+
end
61+
62+
sample(m::MultiBatchSampler, t) = [sample(m.sampler, t) for _ in 1:m.n]

0 commit comments

Comments
 (0)