Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
coverage/
*.jl.*.cov
*.jl.cov
*.jl.mem
/Manifest.toml

.DS_Store
5 changes: 5 additions & 0 deletions src/LastDimSlices.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
export LastDimSlices

using MacroTools: @forward

# See also https://github.com/JuliaLang/julia/pull/32310
5 changes: 2 additions & 3 deletions src/Trajectories.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
module Trajectories

include("samplers.jl")
include("controlers.jl")
include("traces.jl")
include("episodes.jl")
include("samplers.jl")
include("controlers.jl")
include("trajectory.jl")
include("rendering.jl")
include("common/common.jl")

end
37 changes: 2 additions & 35 deletions src/common/CircularArraySARTTraces.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,5 @@
export CircularArraySARTTraces

# Type alias for a `Traces` container whose names are `SART` and whose four
# underlying traces are each a `Trace` wrapping a `CircularArrayBuffer`.
# Methods in this file dispatch on this alias.
const CircularArraySARTTraces = Traces{
SART,
<:Tuple{
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer}
}
}


function CircularArraySARTTraces(;
capacity::Int,
state=Int => (),
Expand All @@ -23,32 +12,10 @@ function CircularArraySARTTraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{(:state, :next_state)}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{(:action, :next_action)}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer
action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

"""
    Random.rand(s::BatchSampler, t::CircularArraySARTTraces)

Draw `s.batch_size` indices uniformly from `1:length(t)` using `s.rng`, gather the
corresponding `state`, `action`, `reward` and `terminal` entries together with the
entries one step later (`next_state`, `next_action`), and pass the resulting
`NamedTuple` through `s.transformer`.

The successor entries are read at `inds .+ 1`; this relies on the `:state` and
`:action` traces holding one more element than `:reward`/`:terminal`.
"""
function Random.rand(s::BatchSampler, t::CircularArraySARTTraces)
    inds = rand(s.rng, 1:length(t), s.batch_size)
    inds′ = inds .+ 1
    return (
        state=t[:state][inds],
        action=t[:action][inds],
        reward=t[:reward][inds],
        terminal=t[:terminal][inds],
        next_state=t[:state][inds′],
        # The successor action must come from the `:action` trace (the original
        # read it from `:state`, returning states under the `next_action` key).
        next_action=t[:action][inds′],
    ) |> s.transformer
end

"""
    push!(t::CircularArraySARTTraces, x::NamedTuple{SA})

Append `x[:state]` and `x[:action]` to the `:state` and `:action` traces of `t`.

If `:state` currently holds exactly one more element than `:terminal`, the last
`:state` and `:action` entries are popped first, so the new pair replaces them.
"""
function Base.push!(t::CircularArraySARTTraces, x::NamedTuple{SA})
    has_extra_step = length(t[:state]) == length(t[:terminal]) + 1
    if has_extra_step
        # Drop the trailing state/action pair before writing the new one.
        foreach(k -> pop!(t[k]), (:state, :action))
    end
    push!(t[:state], x[:state])
    return push!(t[:action], x[:action])
end
46 changes: 4 additions & 42 deletions src/common/CircularArraySLARTTraces.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
export CircularArraySLARTTraces

# Type alias for a `Traces` container whose names are `SLART` and whose five
# underlying traces are each a `Trace` wrapping a `CircularArrayBuffer`.
# Methods in this file dispatch on this alias.
const CircularArraySLARTTraces = Traces{
SLART,
<:Tuple{
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer}
}
}


function CircularArraySLARTTraces(;
capacity::Int,
state=Int => (),
Expand All @@ -26,37 +14,11 @@ function CircularArraySLARTTraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{(:state, :next_state)}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{(:legal_actions_mask, :next_legal_actions_mask)}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
MultiplexTraces{(:action, :next_action)}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer
legal_actions_mask=CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1), # !!! legal_actions_mask is one step longer
action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

"""
    sample(s::BatchSampler, t::CircularArraySLARTTraces)

Draw `s.batch_size` indices uniformly from `1:length(t)` using `s.rng`, gather the
corresponding `state`, `legal_actions_mask`, `action`, `reward` and `terminal`
entries together with the entries one step later (`next_state`,
`next_legal_actions_mask`, `next_action`), and pass the resulting `NamedTuple`
through `s.transformer`.

The successor entries are read at `inds .+ 1`; this relies on the `:state`,
`:legal_actions_mask` and `:action` traces holding one more element than
`:reward`/`:terminal`.
"""
function sample(s::BatchSampler, t::CircularArraySLARTTraces)
    inds = rand(s.rng, 1:length(t), s.batch_size)
    inds′ = inds .+ 1
    return (
        state=t[:state][inds],
        legal_actions_mask=t[:legal_actions_mask][inds],
        action=t[:action][inds],
        reward=t[:reward][inds],
        terminal=t[:terminal][inds],
        next_state=t[:state][inds′],
        next_legal_actions_mask=t[:legal_actions_mask][inds′],
        # The successor action must come from the `:action` trace (the original
        # read it from `:state`, returning states under the `next_action` key).
        next_action=t[:action][inds′],
    ) |> s.transformer
end

"""
    push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA})

Append `x[:state]`, `x[:legal_actions_mask]` and `x[:action]` to the matching
traces of `t`.

If `:state` currently holds exactly one more element than `:terminal`, the last
`:state`, `:legal_actions_mask` and `:action` entries are popped first, so the new
triple replaces them.
"""
function Base.push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA})
    has_extra_step = length(t[:state]) == length(t[:terminal]) + 1
    if has_extra_step
        # Drop the trailing state/mask/action triple before writing the new one.
        foreach(k -> pop!(t[k]), (:state, :legal_actions_mask, :action))
    end
    push!(t[:state], x[:state])
    push!(t[:legal_actions_mask], x[:legal_actions_mask])
    return push!(t[:action], x[:action])
end
end
8 changes: 0 additions & 8 deletions src/common/common.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
using CircularArrayBuffers

# Tuples of trace names. In this file they appear as type parameters, e.g.
# `Traces{SART, ...}` and `NamedTuple{SA}`.
# S = state, L = legal_actions_mask, A = action, R = reward, T = terminal;
# the trailing S/L/A in SARTSA and SLARTSLA are the `next_*` counterparts.
const SA = (:state, :action)
const SLA = (:state, :legal_actions_mask, :action)
const RT = (:reward, :terminal)
const SART = (:state, :action, :reward, :terminal)
const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action)
const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal)
const SLARTSLA = (:state, :legal_actions_mask, :action, :reward, :terminal, :next_state, :next_legal_actions_mask, :next_action)

include("sum_tree.jl")
include("CircularArraySARTTraces.jl")
include("CircularArraySLARTTraces.jl")
53 changes: 19 additions & 34 deletions src/episodes.jl
Original file line number Diff line number Diff line change
@@ -1,54 +1,45 @@
export Episode, Episodes

using MLUtils: batch

"""
Episode(traces)

An `Episode` is a wrapper around [`Traces`](@ref). You can use `(e::Episode)[]`
to check/update whether the episode reaches a terminal or not.
"""
struct Episode{T}
struct Episode{T,E} <: AbstractVector{E}
traces::T
is_done::Ref{Bool}
is_terminated::Ref{Bool}
end

Base.getindex(e::Episode, s::Symbol) = getindex(e.traces, s)
Base.keys(e::Episode) = keys(e.traces)
Base.getindex(e::Episode, I) = getindex(e.traces, I)
Base.getindex(e::Episode) = getindex(e.is_terminated)
Base.setindex!(e::Episode, x::Bool) = setindex!(e.is_terminated, x)

Base.getindex(e::Episode) = getindex(e.is_done)
Base.setindex!(e::Episode, x::Bool) = setindex!(e.is_done, x)
Base.size(e::Episode) = size(e.traces)

Base.length(e::Episode) = length(e.traces)
Episode(t::T) where {T<:AbstractTraces} = Episode{T,eltype(t)}(t, Ref(false))

Episode(t::Traces) = Episode(t, Ref(false))

"""
    push!(t::Episode, x)

Forward `x` to `push!(t.traces, x)`.

Throws an `ArgumentError` when the episode's `is_done` flag is set.
"""
function Base.push!(t::Episode, x)
    # Guard clause: a finished episode must not receive further data.
    t.is_done[] && throw(ArgumentError("The episode is already flagged as done!"))
    return push!(t.traces, x)
end

function Base.append!(t::Episode, x)
if t.is_done[]
throw(ArgumentError("The episode is already flagged as done!"))
else
append!(t.traces, x)
for f in (:push!, :pushfirst!, :append!, :prepend!)
@eval function Base.$f(t::Episode, x)
if t.is_terminated[]
throw(ArgumentError("The episode is already flagged as done!"))
else
$f(t.traces, x)
end
end
end

# Remove the last element of the wrapped traces via `pop!(t.traces)` and clear
# the episode's completion flag(s), so that `push!` no longer rejects new data.
# NOTE(review): the method evaluates to `false` (the value of the final
# assignment), not the removed element as `Base.pop!` conventionally does —
# confirm this is intended.
# NOTE(review): this view shows assignments to both `is_done` and
# `is_terminated`; presumably only one of them belongs to the final revision.
function Base.pop!(t::Episode)
pop!(t.traces)
t.is_done[] = false
t.is_terminated[] = false
end

Base.popfirst!(t::Episode) = popfirst!(t.traces)

# Empty the wrapped traces via `empty!(t.traces)` and clear the episode's
# completion flag(s), returning the episode to a state that accepts new data.
# NOTE(review): the method evaluates to `false` (the value of the final
# assignment), whereas `Base.empty!` conventionally returns the emptied
# collection — confirm this is intended.
# NOTE(review): this view shows assignments to both `is_done` and
# `is_terminated`; presumably only one of them belongs to the final revision.
function Base.empty!(t::Episode)
empty!(t.traces)
t.is_done[] = false
t.is_terminated[] = false
end

#####
Expand All @@ -58,13 +49,14 @@ end

A container for multiple [`Episode`](@ref)s. `init` is a parameterless function which returns an [`Episode`](@ref).
"""
struct Episodes
struct Episodes <: AbstractVector{Episode}
init::Any
episodes::Vector{Episode}
inds::Vector{Tuple{Int,Int}}
end

Base.length(e::Episodes) = length(e.inds)
Base.size(e::Episodes) = size(e.inds)
Base.getindex(e::Episodes, I) = getindex(e.episodes, I)

function Base.push!(e::Episodes, x::Episode)
push!(e.episodes, x)
Expand Down Expand Up @@ -98,10 +90,3 @@ function Base.append!(e::Episodes, x)
push!(e.inds, (lengthe.episodes, i))
end
end

##

"""
    sample(s::BatchSampler, e::Episodes)

Draw `s.batch_size` indices uniformly from `1:length(e)` using `s.rng`. Each index
`i` is resolved through `e.inds[i]` to an `(episode, step)` pair, a view of that
step is taken, the views are combined with `batch`, and the result is passed
through `s.transformer`.
"""
function sample(s::BatchSampler, e::Episodes)
    # Sample over the container `e` (the original referenced an undefined `t`).
    inds = rand(s.rng, 1:length(e), s.batch_size)
    # Episodes live on `e`, not on the sampler `s`.
    steps = [@view(e.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds]
    return batch(steps) |> s.transformer
end
135 changes: 0 additions & 135 deletions src/rendering.jl

This file was deleted.

Loading