Skip to content

Commit e924768

Browse files
Fix SARTSTraces etc. capacity
1 parent de01fb3 commit e924768

File tree

5 files changed

+35
-26
lines changed

5 files changed

+35
-26
lines changed

src/common/CircularArraySARTSATraces.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ function CircularArraySARTSATraces(;
2424
reward_eltype, reward_size = reward
2525
terminal_eltype, terminal_size = terminal
2626

27-
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+2)) +
27+
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
2828
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) +
2929
Traces(
30-
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity+1),
31-
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity+1),
30+
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
31+
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
3232
)
3333
end
3434

35-
CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces)))
35+
CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = minimum(map(capacity,t.traces))

src/common/CircularArraySARTSTraces.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ function CircularArraySARTSTraces(;
1717
state=Int => (),
1818
action=Int => (),
1919
reward=Float32 => (),
20-
terminal=Bool => ())
21-
20+
terminal=Bool => ()
21+
)
2222
state_eltype, state_size = state
2323
action_eltype, action_size = action
2424
reward_eltype, reward_size = reward
@@ -32,4 +32,4 @@ function CircularArraySARTSTraces(;
3232
)
3333
end
3434

35-
CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces)))
35+
CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = minimum(map(capacity,t.traces))

src/common/CircularArraySLARTTraces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ function CircularArraySLARTTraces(;
3434
)
3535
end
3636

37-
CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces)))
37+
CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = minimum(map(capacity,t.traces))

src/common/CircularPrioritizedTraces.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,7 @@ end
1212
function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts}
1313
new_names = (:key, :priority, names...)
1414
new_Ts = Tuple{Int,Float32,Ts.parameters...}
15-
if traces isa CircularArraySARTSATraces
16-
c = capacity(traces) - 1
17-
else
18-
c = capacity(traces)
19-
end
15+
c = capacity(traces)
2016
CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}(
2117
CircularVectorBuffer{Int}(c),
2218
SumTree(c),

test/common.jl

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,20 @@ end
3434
) |> gpu
3535

3636
@test t isa CircularArraySARTSATraces
37+
@test ReinforcementLearningTrajectories.capacity(t) == 3
38+
@test CircularArrayBuffers.capacity(t) == 3
3739

38-
push!(t, (state=ones(Float32, 2, 3),))
40+
push!(t, (state=ones(Float32, 2, 3),) |> gpu)
3941
push!(t, (action=ones(Float32, 2), next_state=ones(Float32, 2, 3) * 2) |> gpu)
4042
@test length(t) == 0
4143

4244
push!(t, (reward=1.0f0, terminal=false) |> gpu)
4345
@test length(t) == 0 # next_action is still missing
4446

45-
push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 2) |> gpu)
47+
push!(t, (action=ones(Float32, 2) * 2,) |> gpu)
48+
@test length(t) == 1
49+
50+
push!(t, (state=ones(Float32, 2, 3) * 3,) |> gpu)
4651
@test length(t) == 1
4752

4853
# this will trigger the scalar indexing of CuArray
@@ -71,29 +76,33 @@ end
7176

7277
@test length(t) == 3
7378

79+
push!(t, (action=ones(Float32, 2) * 6,) |> gpu)
80+
@test length(t) == 3
81+
7482
# this will trigger the scalar indexing of CuArray
7583
CUDA.@allowscalar @test t[1] == (
76-
state=ones(Float32, 2, 3) * 2,
77-
next_state=ones(Float32, 2, 3) * 3,
78-
action=ones(Float32, 2) * 2,
79-
next_action=ones(Float32, 2) * 3,
80-
reward=2.0f0,
84+
state=ones(Float32, 2, 3) * 3,
85+
next_state=ones(Float32, 2, 3) * 4,
86+
action=ones(Float32, 2) * 3,
87+
next_action=ones(Float32, 2) * 4,
88+
reward=3.0f0,
8189
terminal=false,
8290
)
8391
CUDA.@allowscalar @test t[end] == (
84-
state=ones(Float32, 2, 3) * 4,
85-
next_state=ones(Float32, 2, 3) * 5,
86-
action=ones(Float32, 2) * 4,
87-
next_action=ones(Float32, 2) * 5,
88-
reward=4.0f0,
92+
state=ones(Float32, 2, 3) * 5,
93+
next_state=ones(Float32, 2, 3) * 6,
94+
action=ones(Float32, 2) * 5,
95+
next_action=ones(Float32, 2) * 6,
96+
reward=5.0f0,
8997
terminal=false,
9098
)
9199

92100
batch = t[1:3]
93101
@test size(batch.state) == (2, 3, 3)
94102
@test size(batch.action) == (2, 3)
95-
@test batch.reward == [2.0, 3.0, 4.0] |> gpu
103+
@test batch.reward == [3.0, 4.0, 5.0] |> gpu
96104
@test batch.terminal == Bool[0, 0, 0] |> gpu
105+
97106
end
98107

99108
@testset "ElasticArraySARTSTraces" begin
@@ -127,6 +136,8 @@ end
127136
)
128137

129138
@test t isa CircularArraySLARTTraces
139+
@test ReinforcementLearningTrajectories.capacity(t) == 3
140+
@test CircularArrayBuffers.capacity(t) == 3
130141
end
131142

132143
@testset "CircularPrioritizedTraces-SARTS" begin
@@ -136,6 +147,7 @@ end
136147
),
137148
default_priority=1.0f0
138149
)
150+
@test ReinforcementLearningTrajectories.capacity(t) == 3
139151

140152
push!(t, (state=0, action=0))
141153

@@ -196,6 +208,7 @@ end
196208
),
197209
default_priority=1.0f0
198210
)
211+
@test ReinforcementLearningTrajectories.capacity(t) == 3
199212

200213
push!(t, (state=0, action=0))
201214

0 commit comments

Comments
 (0)