Skip to content

Commit 73f2efb

Browse files
author
Dharanish
committed
Fix sampling of CircularPrioritizedTraces
1 parent 0ec5743 commit 73f2efb

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/samplers.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
7474
t = e.traces
7575
p = collect(deepcopy(t.priorities))
7676
w = StatsBase.FrequencyWeights(p)
77-
w .*= e.sampleable_inds[1:end-1]
77+
w .*= e.sampleable_inds[1:length(t)]
7878
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
7979
NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...))
8080
end
@@ -247,7 +247,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
247247
p = collect(deepcopy(t.priorities))
248248
w = StatsBase.FrequencyWeights(p)
249249
valids, ns = valid_range(s,e)
250-
w .*= valids[1:end-1]
250+
w .*= valids[1:length(t)]
251251
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
252252
merge(
253253
(key=t.keys[inds], priority=p[inds]),
@@ -362,7 +362,7 @@ function StatsBase.sample(s::MultiStepSampler{names}, e::EpisodesBuffer{<:Any, <
362362
p = collect(deepcopy(t.priorities))
363363
w = StatsBase.FrequencyWeights(p)
364364
valids, ns = valid_range(s,e)
365-
w .*= valids[1:end-1]
365+
w .*= valids[1:length(t)]
366366
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
367367
merge(
368368
(key=t.keys[inds], priority=p[inds]),

0 commit comments

Comments
 (0)