Skip to content

Commit c202082

Browse files
Merge pull request #106 from JuliaML/cl/partial
fix partial=false bug
2 parents 1020a60 + 547ed93 commit c202082

File tree

5 files changed

+25
-9
lines changed

5 files changed

+25
-9
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
*.jl.mem
44
/Manifest.toml
55
docs/build
6+
.vscode

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLUtils"
22
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
33
authors = ["Carlo Lucibello <[email protected]> and contributors"]
4-
version = "0.2.7"
4+
version = "0.2.8"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/batchview.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,8 @@ struct BatchView{TElem,TData,TCollate} <: AbstractDataContainer
8787
batchsize::Int
8888
count::Int
8989
partial::Bool
90-
imax::Int
9190
end
9291

93-
9492
function BatchView(data::T; batchsize::Int=1, partial::Bool=true, collate=Val(nothing)) where {T}
9593
n = numobs(data)
9694
if n < batchsize
@@ -102,9 +100,8 @@ function BatchView(data::T; batchsize::Int=1, partial::Bool=true, collate=Val(no
102100
throw(ArgumentError("`collate` must be one of `nothing`, `true` or `false`."))
103101
end
104102
E = _batchviewelemtype(data, collate)
105-
imax = partial ? n : n - batchsize + 1
106103
count = partial ? ceil(Int, n / batchsize) : floor(Int, n / batchsize)
107-
BatchView{E,T,typeof(collate)}(data, batchsize, count, partial, imax)
104+
BatchView{E,T,typeof(collate)}(data, batchsize, count, partial)
108105
end
109106

110107
_batchviewelemtype(::TData, ::Val{nothing}) where TData =
@@ -155,10 +152,10 @@ Base.iterate(A::BatchView, state = 1) =
155152
(state > numobs(A)) ? nothing : (A[state], state + 1)
156153

157154
# Helper function to translate a batch-index into a range of observations.
158-
@inline function _batchrange(a::BatchView, batchindex::Int)
159-
@boundscheck (batchindex > a.count || batchindex < 0) && throw(BoundsError())
160-
startidx = (batchindex - 1) * a.batchsize + 1
161-
endidx = min(a.imax, startidx + a.batchsize -1)
155+
@inline function _batchrange(A::BatchView, batchindex::Int)
156+
@boundscheck (batchindex > A.count || batchindex < 0) && throw(BoundsError())
157+
startidx = (batchindex - 1) * A.batchsize + 1
158+
endidx = min(numobs(parent(A)), startidx + A.batchsize -1)
162159
return startidx:endidx
163160
end
164161

test/batchview.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,4 +107,13 @@ using MLUtils: obsview
107107
# @test eltype(@inferred(BatchView(ObsView(var)))[1]) <: SubArray
108108
# end
109109
# end
110+
111+
@testset "partial=false" begin
112+
x = [1:12;]
113+
bv = BatchView(x, batchsize=5, partial=false)
114+
@test length(bv) == 2
115+
@test bv[1] == 1:5
116+
@test bv[2] == 6:10
117+
@test_throws BoundsError bv[3]
118+
end
110119
end

test/dataloader.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@
8181
end
8282
end
8383

84+
@testset "partial=false" begin
85+
x = [1:12;]
86+
d = DataLoader(x, batchsize=5, partial=false) |> collect
87+
@test length(d) == 2
88+
@test d[1] == 1:5
89+
@test d[2] == 6:10
90+
end
91+
8492
@testset "shuffle & rng" begin
8593
X4 = rand(2, 1000)
8694
d1 = DataLoader(X4, batchsize=2; shuffle=true)
@@ -97,6 +105,7 @@
97105
@test first(d1) == first(d2)
98106
end
99107

108+
100109
# numobs/getobs compatibility
101110
d = DataLoader(CustomType(), batchsize=2)
102111
@test first(d) == [1, 2]

0 commit comments

Comments
 (0)