|
129 | 129 | @test t isa CircularArraySLARTTraces |
130 | 130 | end |
131 | 131 |
|
132 | | -@testset "CircularPrioritizedTraces" begin |
| 132 | +@testset "CircularPrioritizedTraces-SARTS" begin |
133 | 133 | t = CircularPrioritizedTraces( |
134 | | - CircularArraySARTSATraces(; |
| 134 | + CircularArraySARTSTraces(; |
135 | 135 | capacity=3 |
136 | 136 | ), |
137 | 137 | default_priority=1.0f0 |
|
161 | 161 | @test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0 |
162 | 162 |
|
163 | 163 | #EpisodesBuffer |
| 164 | + t = CircularPrioritizedTraces( |
| 165 | + CircularArraySARTSTraces(; |
| 166 | + capacity=10 |
| 167 | + ), |
| 168 | + default_priority=1.0f0 |
| 169 | + ) |
| 170 | + |
| 171 | + eb = EpisodesBuffer(t) |
| 172 | + push!(eb, (state = 1, action = 1)) |
| 173 | + for i = 1:5 |
| 174 | + push!(eb, (state = i+1, action =i+1, reward = i, terminal = false)) |
| 175 | + end |
| 176 | + push!(eb, (state = 7, action = 7)) |
| 177 | + for (j,i) = enumerate(8:11) |
| 178 | + push!(eb, (state = i, action =i, reward = i-1, terminal = false)) |
| 179 | + end |
| 180 | + s = BatchSampler(1000) |
| 181 | + b = sample(s, eb) |
| 182 | + cm = counter(b[:state]) |
| 183 | + @test !haskey(cm, 6) |
| 184 | + @test !haskey(cm, 11) |
| 185 | + @test all(in(keys(cm)), [1:5;7:10]) |
| 186 | + |
| 187 | + |
| 188 | + eb[:priority, [1, 2]] = [0, 0] |
| 189 | + @test eb[:priority] == [zeros(2);ones(8)] |
| 190 | +end |
| 191 | + |
| 192 | +@testset "CircularPrioritizedTraces-SARTSA" begin |
164 | 193 | t = CircularPrioritizedTraces( |
165 | 194 | CircularArraySARTSATraces(; |
| 195 | + capacity=3 |
| 196 | + ), |
| 197 | + default_priority=1.0f0 |
| 198 | + ) |
| 199 | + |
| 200 | + push!(t, (state=0, action=0)) |
| 201 | + |
| 202 | + for i in 1:5 |
| 203 | + push!(t, (reward=1.0f0, terminal=false, state=i, action=i)) |
| 204 | + end |
| 205 | + |
| 206 | + @test length(t) == 3 |
| 207 | + |
| 208 | + s = BatchSampler(5) |
| 209 | + |
| 210 | + b = sample(s, t) |
| 211 | + |
| 212 | + t[:priority, [1, 2]] = [0, 0] |
| 213 | + |
| 214 | + # shouldn't be changed since [1,2] are old keys |
| 215 | + @test t[:priority] == [1.0f0, 1.0f0, 1.0f0] |
| 216 | + |
| 217 | + t[:priority, [3, 4, 5]] = [0, 1, 0] |
| 218 | + |
| 219 | + b = sample(s, t) |
| 220 | + |
| 221 | + @test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0 |
| 222 | + |
| 223 | + #EpisodesBuffer |
| 224 | + t = CircularPrioritizedTraces( |
| 225 | + CircularArraySARTSTraces(; |
166 | 226 | capacity=10 |
167 | 227 | ), |
168 | 228 | default_priority=1.0f0 |
|
0 commit comments