@@ -87,10 +87,8 @@ struct BatchView{TElem,TData,TCollate} <: AbstractDataContainer
87
87
batchsize:: Int
88
88
count:: Int
89
89
partial:: Bool
90
- imax:: Int
91
90
end
92
91
93
-
94
92
function BatchView (data:: T ; batchsize:: Int = 1 , partial:: Bool = true , collate= Val (nothing )) where {T}
95
93
n = numobs (data)
96
94
if n < batchsize
@@ -102,9 +100,8 @@ function BatchView(data::T; batchsize::Int=1, partial::Bool=true, collate=Val(no
102
100
throw (ArgumentError (" `collate` must be one of `nothing`, `true` or `false`." ))
103
101
end
104
102
E = _batchviewelemtype (data, collate)
105
- imax = partial ? n : n - batchsize + 1
106
103
count = partial ? ceil (Int, n / batchsize) : floor (Int, n / batchsize)
107
- BatchView {E,T,typeof(collate)} (data, batchsize, count, partial, imax )
104
+ BatchView {E,T,typeof(collate)} (data, batchsize, count, partial)
108
105
end
109
106
110
107
_batchviewelemtype (:: TData , :: Val{nothing} ) where TData =
@@ -155,10 +152,10 @@ Base.iterate(A::BatchView, state = 1) =
155
152
(state > numobs (A)) ? nothing : (A[state], state + 1 )
156
153
157
154
# Helper function to translate a batch-index into a range of observations.
158
- @inline function _batchrange (a :: BatchView , batchindex:: Int )
159
- @boundscheck (batchindex > a . count || batchindex < 0 ) && throw (BoundsError ())
160
- startidx = (batchindex - 1 ) * a . batchsize + 1
161
- endidx = min (a . imax , startidx + a . batchsize - 1 )
155
+ @inline function _batchrange (A :: BatchView , batchindex:: Int )
156
+ @boundscheck (batchindex > A . count || batchindex < 0 ) && throw (BoundsError ())
157
+ startidx = (batchindex - 1 ) * A . batchsize + 1
158
+ endidx = min (numobs ( parent (A)) , startidx + A . batchsize - 1 )
162
159
return startidx: endidx
163
160
end
164
161
0 commit comments