Skip to content

Commit 320e3f8

Browse files
author
Dharanish
committed
Fix test of CircularPrioritizedTraces with SARTSA
The usage of SARTSA traces is more restrictive and should be done in this way
1 parent fbf054a commit 320e3f8

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

test/samplers.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,17 @@ import ReinforcementLearningTrajectories.fetch
130130
batchsize = 4
131131
eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
132132
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batchsize=batchsize)
133-
134-
push!(eb, (state = 1, action = 1))
133+
134+
push!(eb, (state = 1,))
135135
for i = 1:5
136-
push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
136+
push!(eb, (state = i+1, action =i, reward = i, terminal = i == 5))
137137
end
138-
push!(eb, (state = 7, action = 7))
139-
for (j,i) = enumerate(8:11)
140-
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
138+
push!(eb, PartialNamedTuple((action=6,)))
139+
push!(eb, (state = 7,))
140+
for (j,i) = enumerate(7:10)
141+
push!(eb, (state = i+1, action =i, reward = i, terminal = i==10))
141142
end
143+
push!(eb, PartialNamedTuple((action = 11,)))
142144
weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
143145
inds = [i for i in eachindex(weights) if weights[i] == 1]
144146
batch = sample(s1, eb)

0 commit comments

Comments
 (0)