@@ -130,15 +130,17 @@ import ReinforcementLearningTrajectories.fetch
 batchsize = 4
 eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority=10f0))
 s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batchsize=batchsize)
-
-push!(eb, (state = 1, action = 1))
+
+push!(eb, (state = 1,))
 for i = 1:5
-    push!(eb, (state = i+1, action = i+1, reward = i, terminal = i == 5))
+    push!(eb, (state = i+1, action = i, reward = i, terminal = i == 5))
 end
-push!(eb, (state = 7, action = 7))
-for (j,i) = enumerate(8:11)
-    push!(eb, (state = i, action = i, reward = i-1, terminal = false))
+push!(eb, PartialNamedTuple((action=6,)))
+push!(eb, (state = 7,))
+for (j,i) = enumerate(7:10)
+    push!(eb, (state = i+1, action = i, reward = i, terminal = i==10))
 end
+push!(eb, PartialNamedTuple((action = 11,)))
 weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
 inds = [i for i in eachindex(weights) if weights[i] == 1]
 batch = sample(s1, eb)
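
For readers unfamiliar with the SARTSA layout this test exercises: a full step tuple stores the action taken in the previous state, which is why the diff changes action = i+1 to action = i. The action that follows the last pushed state therefore has no step tuple to ride on and is delivered afterwards via PartialNamedTuple. A minimal sketch of that pattern, reusing only the calls that appear in the diff and assuming these names are exported by ReinforcementLearningTrajectories:

    using ReinforcementLearningTrajectories

    # SARTSA episodes buffer; capacity mirrors the test above.
    eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))

    # An episode starts with the initial state alone.
    push!(eb, (state = 1,))

    # Each subsequent push carries the action taken in the previous state,
    # plus the resulting reward, next state, and terminal flag.
    for i in 1:5
        push!(eb, (state = i + 1, action = i, reward = i, terminal = i == 5))
    end

    # The action following the terminal state arrives separately, since
    # SARTSA traces also record the action at the most recent state.
    push!(eb, PartialNamedTuple((action = 6,)))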