Merged
Changes from 8 commits
1 change: 1 addition & 0 deletions Project.toml
@@ -8,6 +8,7 @@ CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+StackViews = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
 Term = "22787eb5-b846-44ae-b979-8e399b8463ab"

 [compat]
71 changes: 62 additions & 9 deletions README.md
@@ -6,22 +6,75 @@

 ## Design

-A typical example of `Trajectory`:
+The relationship of several concepts provided in this package:

-![](https://user-images.githubusercontent.com/5612003/167291629-0e2d4f0f-7c54-460c-a94f-9eb4148cdca0.png)
+```
+┌───────────────────────────────────┐
+│ Trajectory                        │
+│ ┌───────────────────────────────┐ │
+│ │ AbstractTraces                │ │
+│ │             ┌───────────────┐ │ │
+│ │ :trace_A => │ AbstractTrace │ │ │
+│ │             └───────────────┘ │ │
+│ │                               │ │
+│ │             ┌───────────────┐ │ │
+│ │ :trace_B => │ AbstractTrace │ │ │
+│ │             └───────────────┘ │ │
+│ │ ...                ...        │ │
+│ └───────────────────────────────┘ │
+│           ┌───────────┐           │
+│           │  Sampler  │           │
+│           └───────────┘           │
+│           ┌────────────┐          │
+│           │ Controller │          │
+│           └────────────┘          │
+└───────────────────────────────────┘
+```

+## `Trajectory`
+
+A `Trajectory` contains 3 parts:
+
-Exported APIs are:
+- A `container` to store data. (Usually an `AbstractTraces`)
+- A `sampler` to determine how to sample a batch from `container`
+- A `controller` to decide when to sample a new batch from the `container`
+
+Typical usage:

 ```julia
-push!(trajectory; [trace_name=value]...)
-append!(trajectory; [trace_name=value]...)
+julia> t = Trajectory(Traces(a=Int[], b=Bool[]), BatchSampler(3), InsertSampleRatioControler(1.0, 3));

+julia> for i in 1:5
+           push!(t, (a=i, b=iseven(i)))
+       end
+
-for sample in trajectory
-    # consume samples from the trajectory
-end
+julia> for batch in t
+           println(batch)
+       end
+(a = [4, 5, 1], b = Bool[1, 0, 0])
+(a = [3, 2, 4], b = Bool[0, 1, 1])
+(a = [4, 1, 2], b = Bool[1, 0, 1])
 ```

-A wide variety of `container`s, `sampler`s, and `controler`s are provided. For the full list, please read the doc.
+**Traces**
+
+- `Traces`
+- `MultiplexTraces`
+- `CircularSARTTraces`
+- `Episode`
+- `Episodes`
+
+**Samplers**
+
+- `BatchSampler`
+
+**Controllers**
+
+- `InsertSampleRatioController`
+- `AsyncInsertSampleRatioController`
+
+
+Please refer to the tests for common usage. (TODO: generate docs and add links to above data structures)

 ## Acknowledgement
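
The README hunk above says the `controller` decides when a new batch may be drawn, and its example yields three batches after five pushes with arguments `(1.0, 3)`. The following is a minimal sketch of one gating rule that would reproduce that count. `ToyRatioGate`, `on_insert!`, `try_sample!` and `count_batches!` are hypothetical names, and the rule itself is an assumption, not the package's controller.

```julia
# Toy gate; NOT the package's controller. The rule below is an assumption
# chosen so that 5 inserts with (ratio = 1.0, threshold = 3) allow 3 batches.
mutable struct ToyRatioGate
    ratio::Float64      # batches allowed per insert beyond the threshold
    threshold::Int      # inserts required before any sampling is allowed
    n_inserted::Int
    n_sampled::Int
end

ToyRatioGate(ratio, threshold) = ToyRatioGate(ratio, threshold, 0, 0)

# Record that `n` items were pushed into the container.
on_insert!(g::ToyRatioGate, n::Int=1) = (g.n_inserted += n; g)

# Return true (and count the batch) if another batch may be drawn now.
function try_sample!(g::ToyRatioGate)
    g.n_inserted >= g.threshold || return false
    g.n_sampled <= (g.n_inserted - g.threshold) * g.ratio || return false
    g.n_sampled += 1
    return true
end

# Draw batches until the gate closes and report how many were allowed.
function count_batches!(g::ToyRatioGate)
    n = 0
    while try_sample!(g)
        n += 1
    end
    return n
end

gate = ToyRatioGate(1.0, 3)
foreach(_ -> on_insert!(gate), 1:5)   # five pushes, as in the README example
n_batches = count_batches!(gate)
println(n_batches)                    # prints 3
```

Keeping the gate separate from the container means the same storage could be paired with a different sampling schedule without touching the traces.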

5 changes: 0 additions & 5 deletions src/LastDimSlices.jl

This file was deleted.

3 changes: 2 additions & 1 deletion src/Trajectories.jl
@@ -1,7 +1,8 @@
 module Trajectories

+include("patch.jl")
+
 include("traces.jl")
-include("episodes.jl")
 include("samplers.jl")
 include("controlers.jl")
 include("trajectory.jl")
14 changes: 12 additions & 2 deletions src/common/CircularArraySARTTraces.jl
@@ -1,5 +1,15 @@
 export CircularArraySARTTraces

+const CircularArraySARTTraces = Traces{
+    SSAART,
+    <:Tuple{
+        <:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
+        <:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
+        <:Trace{<:CircularArrayBuffer},
+        <:Trace{<:CircularArrayBuffer},
+    }
+}
+
 function CircularArraySARTTraces(;
     capacity::Int,
     state=Int => (),
@@ -12,8 +22,8 @@ function CircularArraySARTTraces(;
     reward_eltype, reward_size = reward
     terminal_eltype, terminal_size = terminal

-    MultiplexTraces{(:state, :next_state)}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
-    MultiplexTraces{(:action, :next_action)}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
+    MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
+    MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
     Traces(
         reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
         terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
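
In the hunk above, the `MultiplexTraces` buffers are allocated with `capacity + 1` slots while `reward` and `terminal` get `capacity`. A plausible reading is that `state` and `next_state` are two offset windows over one shared buffer. The sketch below illustrates that idea with a plain `Vector` and views; it is an assumption about the intent and does not use the package's types.

```julia
# Plain-Julia illustration of two overlapping windows over one buffer.
# This is a sketch of the idea only, not how MultiplexTraces is implemented.
capacity = 3
buffer = collect(1:(capacity + 1)) .* 10   # four stored states: [10, 20, 30, 40]

state      = @view buffer[1:end-1]   # first `capacity` entries
next_state = @view buffer[2:end]     # same storage, shifted by one

println(collect(state))        # prints [10, 20, 30]
println(collect(next_state))   # prints [20, 30, 40]

# Both views alias the buffer, so a write is visible through each of them.
buffer[2] = 99
println((state[2], next_state[1]))   # prints (99, 99)
```

Storing each observation once, rather than once as `state` and again as `next_state`, roughly halves the memory needed for large observations.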
17 changes: 14 additions & 3 deletions src/common/CircularArraySLARTTraces.jl
@@ -1,5 +1,16 @@
 export CircularArraySLARTTraces

+const CircularArraySLARTTraces = Traces{
+    SSLLAART,
+    <:Tuple{
+        <:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
+        <:MultiplexTraces{LL,<:Trace{<:CircularArrayBuffer}},
+        <:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
+        <:Trace{<:CircularArrayBuffer},
+        <:Trace{<:CircularArrayBuffer},
+    }
+}
+
 function CircularArraySLARTTraces(;
     capacity::Int,
     state=Int => (),
@@ -14,9 +25,9 @@ function CircularArraySLARTTraces(;
     reward_eltype, reward_size = reward
     terminal_eltype, terminal_size = terminal

-    MultiplexTraces{(:state, :next_state)}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
-    MultiplexTraces{(:legal_actions_mask, :next_legal_actions_mask)}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
-    MultiplexTraces{(:action, :next_action)}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
+    MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
+    MultiplexTraces{LL}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
+    MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
     Traces(
         reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
         terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
7 changes: 7 additions & 0 deletions src/common/common.jl
@@ -1,5 +1,12 @@
 using CircularArrayBuffers

+const SS = (:state, :next_state)
+const LL = (:legal_actions_mask, :next_legal_actions_mask)
+const AA = (:action, :next_action)
+const RT = (:reward, :terminal)
+const SSAART = (SS..., AA..., RT...)
+const SSLLAART = (SS..., LL..., AA..., RT...)
+
 include("sum_tree.jl")
 include("CircularArraySARTTraces.jl")
 include("CircularArraySLARTTraces.jl")
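
The constants added in `common.jl` are tuples of symbols built by splatting, and they are later used as type parameters (for example `MultiplexTraces{SS}`). The sketch below shows the same two mechanisms in plain Julia, using `NamedTuple` as a stand-in for a type parameterised by a tuple of names. The lowercase variable names are illustrative only.

```julia
# Splatting concatenates the smaller tuples into one flat tuple of symbols.
ss = (:state, :next_state)
aa = (:action, :next_action)
rt = (:reward, :terminal)
ssaart = (ss..., aa..., rt...)
println(ssaart)
# prints (:state, :next_state, :action, :next_action, :reward, :terminal)

# A tuple of symbols can be used as a type parameter. NamedTuple is a
# familiar Base example of a type parameterised this way.
nt = NamedTuple{ssaart}((1, 2, 3, 4, 0.5, false))
println(nt.next_action)   # prints 4
```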
92 changes: 0 additions & 92 deletions src/episodes.jl

This file was deleted.

3 changes: 3 additions & 0 deletions src/patch.jl
@@ -0,0 +1,3 @@
+import MLUtils
+
+MLUtils.batch(x::AbstractArray{<:Number}) = x
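
`patch.jl` adds a method to `MLUtils.batch` so that an array of numbers is returned unchanged rather than being batched again. The dispatch pattern can be sketched without MLUtils, as below. `toy_batch` is a hypothetical stand-in, and its generic fallback (`reduce(hcat, xs)`) is an assumption made for illustration, not a description of what `MLUtils.batch` does.

```julia
# Generic fallback: combine a collection of vectors into the columns of a matrix.
toy_batch(xs) = reduce(hcat, xs)

# More specific method: an array of numbers is treated as already batched.
toy_batch(x::AbstractArray{<:Number}) = x

samples = [[1, 2], [3, 4]]     # element type is Vector{Int}, so the fallback applies
stacked = toy_batch(samples)
println(stacked)               # prints [1 3; 2 4]

numbers = [1, 2, 3]            # element type is Int, so the specific method applies
println(toy_batch(numbers) === numbers)   # prints true
```

Note that the patch in the diff extends a function owned by another package on types the patching package does not own. This is usually called type piracy and can change behaviour for other users of `MLUtils.batch`.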
9 changes: 2 additions & 7 deletions src/samplers.jl
@@ -17,14 +17,9 @@ Uniformly sample a batch of examples for each trace.

 See also [`sample`](@ref).
 """
-BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity) = BatchSampler(batch_size, rng, transformer)
+BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=batch) = BatchSampler(batch_size, rng, transformer)

 function sample(s::BatchSampler, t::AbstractTraces)
     inds = rand(s.rng, 1:length(t), s.batch_size)
-    @view t[inds]
+    map(s.transformer, t[inds])
 end
-
-function sample(s::BatchSampler, e::Episodes)
-    inds = rand(s.rng, 1:length(t), s.batch_size)
-    batch([@view(s.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds]) |> s.transformer
-end
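
The retained `sample` method draws `batch_size` indices uniformly with replacement and applies the transformer to each indexed trace. A minimal stand-alone sketch of that indexing step follows, using a `NamedTuple` of vectors in place of `AbstractTraces`. `toy_sample` is a hypothetical name, and the transformer step is omitted.

```julia
using Random

# Draw `batch_size` indices uniformly with replacement, then index every
# column with the same indices so that the rows stay aligned.
function toy_sample(rng::AbstractRNG, traces::NamedTuple, batch_size::Int)
    n = length(first(traces))
    inds = rand(rng, 1:n, batch_size)
    return map(col -> col[inds], traces)
end

traces = (a = collect(1:5), b = iseven.(1:5))
batch = toy_sample(MersenneTwister(0), traces, 3)
println(batch)   # a NamedTuple with fields `a` and `b`, each of length 3
```

Because indices are drawn with replacement, the same row can appear more than once in a batch, as in the README output where `4` occurs in all three batches.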