@@ -93,10 +93,10 @@ export MetaSampler
9393"""
9494 MetaSampler(::NamedTuple)
9595
96- Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a
96+ Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a
9797batch from each sampler.
9898Used internally for algorithms that sample multiple times per epoch.
99- Note that a single "sampling" with a MetaSampler only increases the Trajectory controller
99+ Note that a single "sampling" with a MetaSampler only increases the Trajectory controller
100100count by 1, not by the number of internal samplers. This should be taken into account when
101101initializing an agent.
102102
@@ -131,15 +131,15 @@ export MultiBatchSampler
131131"""
132132 MultiBatchSampler(sampler, n)
133133
134- Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination
134+ Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination
135135with MetaSampler to allow different sampling rates between samplers.
136- Note that a single "sampling" with a MultiBatchSampler only increases the Trajectory
136+ Note that a single "sampling" with a MultiBatchSampler only increases the Trajectory
137137controller count by 1, not by `n`. This should be taken into account when
138138initializing an agent.
139139
140140# Example
141141```
142- MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3),
142+ MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3),
143143 critic = MultiBatchSampler(BatchSampler(100), 5))
144144```
145145"""
@@ -169,13 +169,13 @@ export NStepBatchSampler
169169 NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng())
170170
171171Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
172- The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
172+ The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
173173that in up to `n > 1` steps later in the buffer. The reward will be
174174the discounted sum of the `n` rewards, with `γ` as the discount factor.
175175
176- NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set
176+ NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set
177177to an integer > 1. This samples the (stacksize - 1) previous states. This is useful in the case
178- of partial observability, for example when the state is approximated by `stacksize` consecutive
178+ of partial observability, for example when the state is approximated by `stacksize` consecutive
179179frames.
180180"""
181181mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int} , R <: AbstractRNG }
@@ -187,17 +187,17 @@ mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRN
187187end
188188
# Convenience constructor: the `names` type parameter is taken from `keys(t)` and
# every keyword argument is forwarded unchanged to the keyword constructor.
function NStepBatchSampler(t::AbstractTraces; kw...)
    return NStepBatchSampler{keys(t)}(; kw...)
end
# Keyword constructor for `NStepBatchSampler{names}`.
#
# `n` and `γ` are required; `batchsize` defaults to 32, `stacksize` to `nothing`
# and `rng` to `Random.default_rng()`. Fails an assertion when `n < 1`.
function NStepBatchSampler{names}(;
    n,
    γ,
    batchsize=32,
    stacksize=nothing,
    rng=Random.default_rng(),
) where {names}
    @assert n >= 1 " n must be ≥ 1."
    # A `stacksize` of 1 is stored as `nothing`, so both spellings produce the
    # same sampler type.
    stack = if stacksize == 1
        nothing
    else
        stacksize
    end
    return NStepBatchSampler{names, typeof(stack), typeof(rng)}(n, γ, batchsize, stack, rng)
end
195195
196196# return a boolean vector of the valid sample indices given the stacksize and the truncated n for each index.
197- function valid_range (s:: NStepBatchSampler , eb:: EpisodesBuffer )
197+ function valid_range (s:: NStepBatchSampler , eb:: EpisodesBuffer )
198198 range = copy (eb. sampleable_inds)
199199 ns = Vector {Int} (undef, length (eb. sampleable_inds))
200- stacksize = isnothing (s. stacksize) ? 1 : s. stacksize
200+ stacksize = isnothing (s. stacksize) ? 1 : s. stacksize
201201 for idx in eachindex (range)
202202 step_number = eb. step_numbers[idx]
203203 range[idx] = step_number >= stacksize && eb. sampleable_inds[idx]
"""
    EpisodesSampler()

Sampler that returns every Episode currently stored in the Trajectory, each one
packed into its own Episode container.

Episodes that were truncated (e.g. because of the buffer capacity) are returned too.
At most one episode can be truncated, and it is always the first one.
"""
struct EpisodesSampler{names} end
@@ -295,7 +295,7 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
295295 idx += 1
296296 end
297297 end
298-
298+
299299 return [make_episode (t, r, names) for r in ranges]
300300end
301301
@@ -304,29 +304,29 @@ end
"""
    MultiStepSampler{names}(; n, batchsize, stacksize=nothing, rng=Random.default_rng())

Sampler that, for every sampled index `x` and every trace, fetches the consecutive
steps `[x, x+1, ..., x + n - 1]`.

The result is an array of `batchsize` elements. For each element, `n` is truncated
by the end of its episode, so the individual samples do not all have the same
dimensions.
"""
struct MultiStepSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
    # Number of consecutive steps requested per sampled index.
    n::Int
    # Number of elements in one returned batch.
    batchsize::Int
    # Either `nothing` or an `Int` stack size.
    stacksize::S
    # Random number generator held by the sampler.
    rng::R
end
318318
# Convenience constructor: the `names` type parameter is taken from `keys(t)` and
# every keyword argument is forwarded unchanged to the keyword constructor.
function MultiStepSampler(t::AbstractTraces; kw...)
    return MultiStepSampler{keys(t)}(; kw...)
end
# Keyword constructor for `MultiStepSampler{names}`.
#
# `n` and `batchsize` are required; `stacksize` defaults to `nothing` and `rng`
# to `Random.default_rng()`. Fails an assertion when `n < 1`.
function MultiStepSampler{names}(;
    n::Int,
    batchsize,
    stacksize=nothing,
    rng=Random.default_rng(),
) where {names}
    @assert n >= 1 " n must be ≥ 1."
    # A `stacksize` of 1 is stored as `nothing`, so both spellings produce the
    # same sampler type.
    stack = if stacksize == 1
        nothing
    else
        stacksize
    end
    return MultiStepSampler{names, typeof(stack), typeof(rng)}(n, batchsize, stack, rng)
end
325325
326- function valid_range (s:: MultiStepSampler , eb:: EpisodesBuffer )
326+ function valid_range (s:: MultiStepSampler , eb:: EpisodesBuffer )
327327 range = copy (eb. sampleable_inds)
328328 ns = Vector {Int} (undef, length (eb. sampleable_inds))
329- stacksize = isnothing (s. stacksize) ? 1 : s. stacksize
329+ stacksize = isnothing (s. stacksize) ? 1 : s. stacksize
330330 for idx in eachindex (range)
331331 step_number = eb. step_numbers[idx]
332332 range[idx] = step_number >= stacksize && eb. sampleable_inds[idx]
@@ -353,7 +353,7 @@ function fetch(::MultiStepSampler, trace, ::Val, inds, ns)
353353 [trace[idx: (idx + ns[i] - 1 )] for (i,idx) in enumerate (inds)]
354354end
355355
# `fetch` method for samplers with an integer `stacksize`, applied to the `:state`
# and `:next_state` traces.
#
# For the j-th entry `idx` of `inds`, builds a `stacksize × ns[j]` matrix of
# indices whose column `n` holds `idx + n - stacksize` through `idx + n - 1`,
# and indexes `trace` with that matrix. One result per entry of `inds`.
function fetch(
    s::MultiStepSampler{names, Int},
    trace::AbstractTrace,
    ::Union{Val{:state}, Val{:next_state}},
    inds,
    ns,
) where {names}
    return map(enumerate(inds)) do (j, idx)
        stack_offsets = (-s.stacksize + 1):0
        step_offsets = 1:ns[j]
        window = [idx + i + n - 1 for i in stack_offsets, n in step_offsets]
        trace[window]
    end
end
359359
0 commit comments