Skip to content

Commit cc688a2

Browse files
authored
Merge pull request #6 from findmyway/add_tests
Adjust to RL.jl
2 parents dc220ed + edd4c52 commit cc688a2

16 files changed

+400
-195
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ version = "0.1.0"
55

66
[deps]
77
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
8+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
9+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
810
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
911
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"
1012

README.md

Lines changed: 14 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,54 +6,22 @@
66

77
## Design
88

9-
```
10-
┌────────────────────────────┐
11-
│(state=..., action=..., ...)│
12-
└──────────────┬─────────────┘
13-
push! │ append!
14-
┌───────────────────▼───────────────────┐
15-
│ Trajectory │
16-
│ ┌─────────────────────────────────┐ │
17-
│ │ Traces │ │
18-
│ │ ┌───────────────────┐ │ │
19-
│ │ state: │CircularArrayBuffer│ │ │
20-
│ │ └───────────────────┘ │ │
21-
│ │ ┌───────────────────┐ │ │
22-
│ │ action:│CircularArrayBuffer│ │ │
23-
│ │ └───────────────────┘ │ │
24-
│ │ ...... │ │
25-
│ └─────────────────────────────────┘ │
26-
| Sampler |
27-
└───────────────────┬───────────────────┘
28-
│ batch sampling
29-
┌──────────────▼─────────────┐
30-
│(state=..., action=..., ...)│
31-
└────────────────────────────┘
32-
```
9+
A typical example of `Trajectory`:
3310

11+
![](https://user-images.githubusercontent.com/5612003/167291629-0e2d4f0f-7c54-460c-a94f-9eb4148cdca0.png)
12+
13+
Exported APIs are:
14+
15+
```julia
16+
push!(trajectory; [trace_name=value]...)
17+
append!(trajectory; [trace_name=value]...)
18+
19+
for sample in trajectory
20+
# consume samples from the trajectory
21+
end
3422
```
35-
┌──────────────┐ ┌──────────────┐
36-
│Single Element│ │Batch Elements│
37-
└──────┬───────┘ └──────┬───────┘
38-
│ │
39-
push! └──────┐ ┌───────┘ append!
40-
│ │
41-
┌─────────────┼────┼─────────────────────────────┐
42-
│ ┌──▼────▼──┐ AsyncTrajectory │
43-
│ │Channel In│ │
44-
│ └─────┬────┘ │
45-
│ take! │ │
46-
│ ┌─────▼─────┐ push! ┌────────────┐ │
47-
│ │RateLimiter├──────────► Trajectory │ │
48-
│ └─────┬─────┘ append! └────*───────┘ │
49-
│ │ * │
50-
│ put! │********************** │
51-
│ │ batch sampling │
52-
│ ┌─────▼─────┐ │
53-
│ │Channel Out│ │
54-
│ └───────────┘ │
55-
└────────────────────────────────────────────────┘
56-
```
23+
24+
A wide variety of `container`s, `sampler`s, and `controler`s are provided. For the full list, please refer to the documentation.
5725

5826
## Acknowledgement
5927

src/Trajectories.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
module Trajectories
22

3+
include("samplers.jl")
4+
include("controlers.jl")
35
include("traces.jl")
46
include("episodes.jl")
57
include("trajectory.jl")
6-
include("samplers.jl")
7-
include("async_trajectory.jl")
88
include("rendering.jl")
99
include("common/common.jl")
1010

src/async_trajectory.jl

Lines changed: 0 additions & 120 deletions
This file was deleted.

src/common/CircularArraySARTTraces.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,14 @@ function Random.rand(s::BatchSampler, t::CircularArraySARTTraces)
4141
terminal=t[:terminal][inds],
4242
next_state=t[:state][inds′],
4343
next_action=t[:state][inds′]
44-
)
44+
) |> s.transformer
45+
end
46+
47+
# Insert a `(state, action)` pair into the traces.
#
# When `:state` holds exactly one more entry than `:terminal`, the most recent
# state/action entries are dropped first, so the new pair takes their place
# (presumably the trailing pair has not been matched with a reward/terminal
# yet -- TODO confirm against the callers).
function Base.push!(t::CircularArraySARTTraces, x::NamedTuple{SA})
    has_extra_tail = length(t[:state]) == length(t[:terminal]) + 1
    if has_extra_tail
        for name in (:state, :action)
            pop!(t[name])
        end
    end
    push!(t[:state], x.state)
    return push!(t[:action], x.action)
end

src/common/CircularArraySLARTTraces.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function CircularArraySLARTTraces(;
3535
)
3636
end
3737

38-
function Random.rand(s::BatchSampler, t::CircularArraySLARTTraces)
38+
function sample(s::BatchSampler, t::CircularArraySLARTTraces)
3939
inds = rand(s.rng, 1:length(t), s.batch_size)
4040
inds′ = inds .+ 1
4141
(
@@ -47,5 +47,16 @@ function Random.rand(s::BatchSampler, t::CircularArraySLARTTraces)
4747
next_state=t[:state][inds′],
4848
next_legal_actions_mask=t[:legal_actions_mask][inds′],
4949
next_action=t[:state][inds′]
50-
)
51-
end
50+
) |> s.transformer
51+
end
52+
53+
# Insert a `(state, legal_actions_mask, action)` triple into the traces.
#
# When `:state` holds exactly one more entry than `:terminal`, the most recent
# state/mask/action entries are dropped first, so the new triple takes their
# place (presumably the trailing triple has not been matched with a
# reward/terminal yet -- TODO confirm against the callers).
function Base.push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA})
    has_extra_tail = length(t[:state]) == length(t[:terminal]) + 1
    if has_extra_tail
        for name in (:state, :legal_actions_mask, :action)
            pop!(t[name])
        end
    end
    push!(t[:state], x.state)
    push!(t[:legal_actions_mask], x.legal_actions_mask)
    return push!(t[:action], x.action)
end

src/common/common.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
using CircularArrayBuffers
22

3+
const SA = (:state, :action)
4+
const SLA = (:state, :legal_actions_mask, :action)
5+
const RT = (:reward, :terminal)
36
const SART = (:state, :action, :reward, :terminal)
47
const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action)
58
const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal)

src/controlers.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
export InsertSampleRatioControler, AsyncInsertSampleRatioControler
2+
3+
mutable struct InsertSampleRatioControler
    ratio::Float64    # samples allowed per insertion made beyond `threshold`
    threshold::Int    # insertions required before any sampling is allowed
    n_inserted::Int   # running count of insertions reported via `on_insert!`
    n_sampled::Int    # running count of samples granted by `on_sample!`
end

"""
    InsertSampleRatioControler(ratio, threshold)

Used in [`Trajectory`](@ref). The `threshold` is the minimal number of
insertions required before sampling is allowed. The `ratio` balances the number
of insertions against the number of samples: sampling is granted while
`n_sampled <= (n_inserted - threshold) * ratio`.

Both counters start at zero.
"""
function InsertSampleRatioControler(ratio, threshold)
    return InsertSampleRatioControler(ratio, threshold, 0, 0)
end

"""
    on_insert!(c::InsertSampleRatioControler, n::Int)

Record that `n` elements were inserted. Non-positive `n` is ignored.
"""
function on_insert!(c::InsertSampleRatioControler, n::Int)
    if n > 0
        c.n_inserted += n
    end
end

"""
    on_sample!(c::InsertSampleRatioControler) -> Bool

Return `true` and count one more sample when sampling is currently allowed,
i.e. at least `threshold` elements were inserted and
`n_sampled <= (n_inserted - threshold) * ratio`. Return `false` otherwise,
leaving the counters untouched.
"""
function on_sample!(c::InsertSampleRatioControler)
    budget = (c.n_inserted - c.threshold) * c.ratio
    if c.n_inserted >= c.threshold && c.n_sampled <= budget
        c.n_sampled += 1
        return true
    end
    # Always answer with a `Bool`: falling through would yield `nothing`, which
    # throws a `TypeError` as soon as a caller writes `if on_sample!(c)`.
    return false
end
33+
34+
#####
35+
36+
"""
    AsyncInsertSampleRatioControler(ratio, threshold; ch_in_sz=1, ch_out_sz=1, n_inserted=0, n_sampled=0)

Carries the same `ratio` / `threshold` / counter fields as
[`InsertSampleRatioControler`](@ref), plus two buffered `Channel`s: `ch_in`
with capacity `ch_in_sz` and `ch_out` with capacity `ch_out_sz`. The counters
may be seeded through the `n_inserted` and `n_sampled` keywords.
"""
mutable struct AsyncInsertSampleRatioControler
    ratio::Float64
    threshold::Int
    n_inserted::Int
    n_sampled::Int
    ch_in::Channel
    ch_out::Channel
end

function AsyncInsertSampleRatioControler(
    ratio, threshold; ch_in_sz=1, ch_out_sz=1, n_inserted=0, n_sampled=0
)
    inbound = Channel(ch_in_sz)
    outbound = Channel(ch_out_sz)
    return AsyncInsertSampleRatioControler(
        ratio, threshold, n_inserted, n_sampled, inbound, outbound
    )
end

0 commit comments

Comments
 (0)