Skip to content

Commit 0ec5743

Browse files
author
Dharanish
committed
Fix CircularPrioritizedTraces with SARTSA
1 parent 4a3be9c commit 0ec5743

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

src/common/CircularPrioritizedTraces.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@ end
1212
function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts}
1313
new_names = (:key, :priority, names...)
1414
new_Ts = Tuple{Int,Float32,Ts.parameters...}
15-
c = capacity(traces)
15+
if traces isa CircularArraySARTSATraces
16+
c = capacity(traces) - 1
17+
else
18+
c = capacity(traces)
19+
end
1620
CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}(
1721
CircularVectorBuffer{Int}(c),
1822
SumTree(c),
@@ -34,6 +38,22 @@ function Base.push!(t::CircularPrioritizedTraces, x)
3438
end
3539
end
3640

41+
"""
    push!(t::CircularPrioritizedTraces{<:CircularArraySARTSATraces}, x)

Push `x` into the wrapped traces `t.traces` and, when the push yields a new entry,
register a new key in `t.keys` together with `t.default_priority` in `t.priorities`.

A key is added when the traces hold exactly one element after the push, or when they
hold more than one element and either the length grew or the length before the push was
already `capacity(t.traces) - 1`. Any other push leaves keys and priorities untouched
and returns `nothing`.
"""
function Base.push!(t::CircularPrioritizedTraces{<:CircularArraySARTSATraces}, x)
    len_before = length(t.traces)
    push!(t.traces, x)
    len_after = length(t.traces)

    if len_after == 1
        # Very first entry: keys start counting at 1.
        next_key = 1
    elseif len_after > 1 && (len_before < len_after || len_before == capacity(t.traces) - 1)
        # The length changed after inserting the tuple, or the traces were already
        # at capacity before the push: continue the key sequence.
        next_key = t.keys[end] + 1
    else
        # Possibly a partial insertion at the first step; nothing to register.
        return nothing
    end

    push!(t.keys, next_key)
    return push!(t.priorities, t.default_priority)
end
56+
3757
function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)
3858
if k === :priority
3959
@assert length(vs) == length(keys)
@@ -48,6 +68,7 @@ function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)
4868
end
4969

5070
"""
    size(t::CircularPrioritizedTraces)

Return the size of the wrapped traces `t.traces`.
"""
function Base.size(t::CircularPrioritizedTraces)
    return size(t.traces)
end
71+
"""
    max_length(t::CircularPrioritizedTraces)

Forward to `max_length` of the wrapped traces `t.traces`.
"""
function max_length(t::CircularPrioritizedTraces)
    return max_length(t.traces)
end
5172

5273
function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol)
5374
if s === :priority

src/episodes.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,20 @@ function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularArraySARTSATraces},
184184
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
185185
end
186186

187+
"""
    push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces{<:CircularArraySARTSATraces}},
          xs::PartialNamedTuple{@NamedTuple{action::Int64}})

Push an action-only partial tuple into an `EpisodesBuffer` backed by prioritized
circular SARTSA traces.

When `max_length(eb) == capacity(eb.traces)`, the action is merged with zero-valued
`:state`, `:reward` and `:terminal` placeholders and pushed as one named tuple; the last
element of each of those three traces is then popped again.
Otherwise the wrapped named tuple is pushed as is and `eb.sampleable_inds[end-1]` is set
to `1`.
"""
function Base.push!(
    eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces{<:CircularArraySARTSATraces}},
    xs::PartialNamedTuple{@NamedTuple{action::Int64}},
)
    if max_length(eb) == capacity(eb.traces)
        # Zero placeholders for every trace the partial tuple does not carry.
        addition = (
            name => zero(eltype(eb.traces[name])) for name in [:state, :reward, :terminal]
        )
        # `merge` folds the `name => value` pairs into the action tuple. Wrapping both in
        # a plain tuple would hand `push!` a `(NamedTuple, Generator)` pair rather than a
        # single named tuple with all four fields.
        padded = merge(xs.namedtuple, addition)
        push!(eb.traces, padded)
        # NOTE(review): presumably these pops discard the placeholders pushed just above
        # so that only the action is retained — confirm against the traces' push! logic.
        pop!(eb.traces[:state].trace)
        pop!(eb.traces[:reward])
        pop!(eb.traces[:terminal])
    else
        push!(eb.traces, xs.namedtuple)
        eb.sampleable_inds[end-1] = 1
    end
end
200+
187201
for f in (:pop!, :popfirst!)
188202
@eval function Base.$f(eb::EpisodesBuffer)
189203
$f(eb.episodes_lengths)

0 commit comments

Comments
 (0)