Skip to content

Commit 1560ff5

Browse files
Remove whitespace
1 parent e924768 commit 1560ff5

File tree

2 files changed

+27
-29
lines changed

2 files changed

+27
-29
lines changed

src/episodes.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ using ElasticArrays: ElasticArray, ElasticVector
55
"""
66
EpisodesBuffer(traces::AbstractTraces)
77
8-
Wraps an `AbstractTraces` object, usually the container of a `Trajectory`.
8+
Wraps an `AbstractTraces` object, usually the container of a `Trajectory`.
99
`EpisodesBuffer` tracks the indexes of the `traces` object that belong to the same episodes.
10-
To that end, it stores
10+
To that end, it stores
1111
1. an vector `sampleable_inds` of Booleans that determine whether an index in Traces is legally sampleable
1212
(i.e., it is not the index of a last state of an episode);
1313
2. a vector `episodes_lengths` that contains the total duration of the episode that each step belong to;
@@ -32,7 +32,7 @@ end
3232
"""
3333
PartialNamedTuple(::NamedTuple)
3434
35-
Wraps a NamedTuple to signal an EpisodesBuffer that it is pushed into that it should
35+
Wraps a NamedTuple to signal an EpisodesBuffer that it is pushed into that it should
3636
ignore the fact that this is a partial insertion. Used at the end of an episode to
3737
complete multiplex traces before moving to the next episode.
3838
"""
@@ -118,8 +118,6 @@ pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))
118118
end
119119
elseif traces_signature <: Tuple
120120
traces_signature = traces_signature.parameters
121-
122-
123121
for tr in traces_signature
124122
if !(tr <: MultiplexTraces)
125123
#push a duplicate of last element as a dummy element, should never be sampled.
@@ -171,7 +169,7 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
171169
return nothing
172170
end
173171

174-
function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTuple to push without incrementing the step number.
172+
function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTuple to push without incrementing the step number.
175173
push!(eb.traces, xs.namedtuple)
176174
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
177175
end

src/samplers.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ export MetaSampler
9393
"""
9494
MetaSampler(::NamedTuple)
9595
96-
Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a
96+
Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a
9797
batch from each sampler.
9898
Used internally for algorithms that sample multiple times per epoch.
99-
Note that a single "sampling" with a MetaSampler only increases the Trajectory controler
99+
Note that a single "sampling" with a MetaSampler only increases the Trajectory controler
100100
count by 1, not by the number of internal samplers. This should be taken into account when
101101
initializing an agent.
102102
@@ -131,15 +131,15 @@ export MultiBatchSampler
131131
"""
132132
MultiBatchSampler(sampler, n)
133133
134-
Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination
134+
Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination
135135
with MetaSampler to allow different sampling rates between samplers.
136-
Note that a single "sampling" with a MultiBatchSampler only increases the Trajectory
136+
Note that a single "sampling" with a MultiBatchSampler only increases the Trajectory
137137
controler count by 1, not by `n`. This should be taken into account when
138138
initializing an agent.
139139
140140
# Example
141141
```
142-
MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3),
142+
MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3),
143143
critic = MultiBatchSampler(BatchSampler(100), 5))
144144
```
145145
"""
@@ -169,13 +169,13 @@ export NStepBatchSampler
169169
NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.GLOBAL_RNG)
170170
171171
Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
172-
The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
172+
The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
173173
that in up to `n > 1` steps later in the buffer. The reward will be
174174
the discounted sum of the `n` rewards, with `γ` as the discount factor.
175175
176-
NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set
176+
NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set
177177
to an integer > 1. This samples the (stacksize - 1) previous states. This is useful in the case
178-
of partial observability, for example when the state is approximated by `stacksize` consecutive
178+
of partial observability, for example when the state is approximated by `stacksize` consecutive
179179
frames.
180180
"""
181181
mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
@@ -187,17 +187,17 @@ mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRN
187187
end
188188

189189
NStepBatchSampler(t::AbstractTraces; kw...) = NStepBatchSampler{keys(t)}(; kw...)
190-
function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
190+
function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
191191
@assert n >= 1 "n must be ≥ 1."
192192
ss = stacksize == 1 ? nothing : stacksize
193193
NStepBatchSampler{names, typeof(ss), typeof(rng)}(n, γ, batchsize, ss, rng)
194194
end
195195

196196
#return a boolean vector of the valid sample indices given the stacksize and the truncated n for each index.
197-
function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
197+
function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
198198
range = copy(eb.sampleable_inds)
199199
ns = Vector{Int}(undef, length(eb.sampleable_inds))
200-
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
200+
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
201201
for idx in eachindex(range)
202202
step_number = eb.step_numbers[idx]
203203
range[idx] = step_number >= stacksize && eb.sampleable_inds[idx]
@@ -258,9 +258,9 @@ end
258258
"""
259259
EpisodesSampler()
260260
261-
A sampler that samples all Episodes present in the Trajectory and divides them into
261+
A sampler that samples all Episodes present in the Trajectory and divides them into
262262
Episode containers. Truncated Episodes (e.g. due to the buffer capacity) are sampled as well.
263-
There will be at most one truncated episode and it will always be the first one.
263+
There will be at most one truncated episode and it will always be the first one.
264264
"""
265265
struct EpisodesSampler{names}
266266
end
@@ -295,7 +295,7 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
295295
idx += 1
296296
end
297297
end
298-
298+
299299
return [make_episode(t, r, names) for r in ranges]
300300
end
301301

@@ -304,29 +304,29 @@ end
304304
"""
305305
MultiStepSampler{names}(batchsize, n, stacksize, rng)
306306
307-
Sampler that fetches steps `[x, x+1, ..., x + n -1]` for each trace of each sampled index
308-
`x`. The samples are returned in an array of batchsize elements. For each element, n is
309-
truncated by the end of its episode. This means that the dimensions of each sample are not
310-
the same.
307+
Sampler that fetches steps `[x, x+1, ..., x + n -1]` for each trace of each sampled index
308+
`x`. The samples are returned in an array of batchsize elements. For each element, n is
309+
truncated by the end of its episode. This means that the dimensions of each sample are not
310+
the same.
311311
"""
312312
struct MultiStepSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
313313
n::Int
314314
batchsize::Int
315315
stacksize::S
316-
rng::R
316+
rng::R
317317
end
318318

319319
MultiStepSampler(t::AbstractTraces; kw...) = MultiStepSampler{keys(t)}(; kw...)
320-
function MultiStepSampler{names}(; n::Int, batchsize, stacksize=nothing, rng=Random.default_rng()) where {names}
320+
function MultiStepSampler{names}(; n::Int, batchsize, stacksize=nothing, rng=Random.default_rng()) where {names}
321321
@assert n >= 1 "n must be ≥ 1."
322322
ss = stacksize == 1 ? nothing : stacksize
323323
MultiStepSampler{names, typeof(ss), typeof(rng)}(n, batchsize, ss, rng)
324324
end
325325

326-
function valid_range(s::MultiStepSampler, eb::EpisodesBuffer)
326+
function valid_range(s::MultiStepSampler, eb::EpisodesBuffer)
327327
range = copy(eb.sampleable_inds)
328328
ns = Vector{Int}(undef, length(eb.sampleable_inds))
329-
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
329+
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
330330
for idx in eachindex(range)
331331
step_number = eb.step_numbers[idx]
332332
range[idx] = step_number >= stacksize && eb.sampleable_inds[idx]
@@ -353,7 +353,7 @@ function fetch(::MultiStepSampler, trace, ::Val, inds, ns)
353353
[trace[idx:(idx + ns[i] - 1)] for (i,idx) in enumerate(inds)]
354354
end
355355

356-
function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names}
356+
function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names}
357357
[trace[[idx + i + n - 1 for i in -s.stacksize+1:0, n in 1:ns[j]]] for (j,idx) in enumerate(inds)]
358358
end
359359

0 commit comments

Comments
 (0)