Skip to content

Commit fbf054a

Browse files
author
Dharanish
committed
New test for CircularPrioritizedTraces with SARTSA
1 parent 73f2efb commit fbf054a

File tree

1 file changed

+62
-2
lines changed

1 file changed

+62
-2
lines changed

test/common.jl

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ end
129129
@test t isa CircularArraySLARTTraces
130130
end
131131

132-
@testset "CircularPrioritizedTraces" begin
132+
@testset "CircularPrioritizedTraces-SARTS" begin
133133
t = CircularPrioritizedTraces(
134-
CircularArraySARTSATraces(;
134+
CircularArraySARTSTraces(;
135135
capacity=3
136136
),
137137
default_priority=1.0f0
@@ -161,8 +161,68 @@ end
161161
@test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0
162162

163163
#EpisodesBuffer
164+
t = CircularPrioritizedTraces(
165+
CircularArraySARTSTraces(;
166+
capacity=10
167+
),
168+
default_priority=1.0f0
169+
)
170+
171+
eb = EpisodesBuffer(t)
172+
push!(eb, (state = 1, action = 1))
173+
for i = 1:5
174+
push!(eb, (state = i+1, action =i+1, reward = i, terminal = false))
175+
end
176+
push!(eb, (state = 7, action = 7))
177+
for (j,i) = enumerate(8:11)
178+
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
179+
end
180+
s = BatchSampler(1000)
181+
b = sample(s, eb)
182+
cm = counter(b[:state])
183+
@test !haskey(cm, 6)
184+
@test !haskey(cm, 11)
185+
@test all(in(keys(cm)), [1:5;7:10])
186+
187+
188+
eb[:priority, [1, 2]] = [0, 0]
189+
@test eb[:priority] == [zeros(2);ones(8)]
190+
end
191+
192+
@testset "CircularPrioritizedTraces-SARTSA" begin
164193
t = CircularPrioritizedTraces(
165194
CircularArraySARTSATraces(;
195+
capacity=3
196+
),
197+
default_priority=1.0f0
198+
)
199+
200+
push!(t, (state=0, action=0))
201+
202+
for i in 1:5
203+
push!(t, (reward=1.0f0, terminal=false, state=i, action=i))
204+
end
205+
206+
@test length(t) == 3
207+
208+
s = BatchSampler(5)
209+
210+
b = sample(s, t)
211+
212+
t[:priority, [1, 2]] = [0, 0]
213+
214+
# shouldn't be changed since [1,2] are old keys
215+
@test t[:priority] == [1.0f0, 1.0f0, 1.0f0]
216+
217+
t[:priority, [3, 4, 5]] = [0, 1, 0]
218+
219+
b = sample(s, t)
220+
221+
@test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0
222+
223+
#EpisodesBuffer
224+
t = CircularPrioritizedTraces(
225+
CircularArraySARTSTraces(;
166226
capacity=10
167227
),
168228
default_priority=1.0f0

0 commit comments

Comments
 (0)