Skip to content

Commit 7d0d053

Browse files
author
Jeremiah Lewis
committed
add bounds check
1 parent 58bfea5 commit 7d0d053

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

src/episodes.jl

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,20 @@ function EpisodesBuffer(traces::AbstractTraces)
6969
end
7070
end
7171

72-
Base.getindex(es::EpisodesBuffer, idx...) = getindex(es.traces, idx...)
73-
Base.setindex!(es::EpisodesBuffer, idx...) = setindex!(es.traces, idx...)
74-
Base.size(es::EpisodesBuffer) = size(es.traces)
75-
Base.length(es::EpisodesBuffer) = length(es.traces)
76-
Base.keys(es::EpisodesBuffer) = keys(es.traces)
77-
Base.keys(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(es.traces.traces)
72+
function Base.getindex(es::EpisodesBuffer, idx::Int...)
73+
@boundscheck all(es.sampleable_inds[idx...])
74+
getindex(es.traces, idx...)
75+
end
76+
77+
function Base.getindex(es::EpisodesBuffer, idx...)
78+
getindex(es.traces, idx...)
79+
end
80+
81+
Base.setindex!(eb::EpisodesBuffer, idx...) = setindex!(eb.traces, idx...)
82+
Base.size(eb::EpisodesBuffer) = size(eb.traces)
83+
Base.length(eb::EpisodesBuffer) = length(eb.traces)
84+
Base.keys(eb::EpisodesBuffer) = keys(eb.traces)
85+
Base.keys(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(eb.traces.traces)
7886
function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where {names}
7987
s = nameof(typeof(eb))
8088
t = eb.traces
@@ -83,7 +91,7 @@ function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where
8391
end
8492

8593
ispartial_insert(traces::Traces, xs) = length(xs) < length(traces.traces) #this is the number of traces it contains not the number of steps.
86-
ispartial_insert(es::EpisodesBuffer, xs) = ispartial_insert(es.traces, xs)
94+
ispartial_insert(eb::EpisodesBuffer, xs) = ispartial_insert(eb.traces, xs)
8795
ispartial_insert(traces::CircularPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs)
8896

8997
function pad!(trace::Trace)
@@ -126,9 +134,9 @@ pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))
126134
return :($ex)
127135
end
128136

129-
fill_multiplex(es::EpisodesBuffer) = fill_multiplex(es.traces)
137+
fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces)
130138

131-
fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(es.traces.traces)
139+
fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(eb.traces.traces)
132140

133141
function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
134142
push!(eb.traces, xs)
@@ -165,17 +173,17 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl
165173
end
166174

167175
for f in (:pop!, :popfirst!)
168-
@eval function Base.$f(es::EpisodesBuffer)
169-
$f(es.episodes_lengths)
170-
$f(es.sampleable_inds)
171-
$f(es.step_numbers)
172-
$f(es.traces)
176+
@eval function Base.$f(eb::EpisodesBuffer)
177+
$f(eb.episodes_lengths)
178+
$f(eb.sampleable_inds)
179+
$f(eb.step_numbers)
180+
$f(eb.traces)
173181
end
174182
end
175183

176-
function Base.empty!(es::EpisodesBuffer)
177-
empty!(es.traces)
178-
empty!(es.episodes_lengths)
179-
empty!(es.sampleable_inds)
180-
empty!(es.step_numbers)
184+
function Base.empty!(eb::EpisodesBuffer)
185+
empty!(eb.traces)
186+
empty!(eb.episodes_lengths)
187+
empty!(eb.sampleable_inds)
188+
empty!(eb.step_numbers)
181189
end

0 commit comments

Comments
 (0)