Skip to content

Commit 0d9b2c6

Browse files
committed
Fix edge case BoundsError in split_at
1 parent a82e72e commit 0d9b2c6

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

src/basic.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ end
101101
Base.getindex(v::MemoryView, ::Colon) = v
102102
Base.@propagate_inbounds Base.view(v::MemoryView, idx::AbstractUnitRange) = v[idx]
103103

104+
# Efficient way to get `mem[1:include_last]`.
105+
# include_last must be in 0:length(mem)
104106
function truncate(mem::MemoryView, include_last::Integer)
105107
lst = Int(include_last)::Int
106108
@boundscheck if (lst % UInt) > length(mem) % UInt
@@ -109,6 +111,8 @@ function truncate(mem::MemoryView, include_last::Integer)
109111
typeof(mem)(unsafe, mem.ref, lst)
110112
end
111113

114+
# Efficient way to get `mem[from:end]`.
115+
# From must be in 1:length(mem).
112116
function truncate_start_nonempty(mem::MemoryView, from::Integer)
113117
frm = Int(from)::Int
114118
@boundscheck if ((frm - 1) % UInt) length(mem) % UInt
@@ -118,11 +122,14 @@ function truncate_start_nonempty(mem::MemoryView, from::Integer)
118122
typeof(mem)(unsafe, newref, length(mem) - frm + 1)
119123
end
120124

125+
# Efficient way to get `mem[from:end]`.
126+
# From must be in 1:length(mem)+1.
121127
function truncate_start(mem::MemoryView, from::Integer)
122128
frm = Int(from)::Int
123129
@boundscheck if ((frm - 1) % UInt) > length(mem) % UInt
124130
throw(BoundsError(mem, frm))
125131
end
132+
frm == 1 && return mem
126133
newref = @inbounds memoryref(mem.ref, frm - (from == length(mem) + 1))
127134
typeof(mem)(unsafe, newref, length(mem) - frm + 1)
128135
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,12 @@ end
434434
@test_throws BoundsError split_first(mem)
435435
@test_throws BoundsError split_last(mem)
436436
end
437+
438+
# Split empty mem at
439+
mem = MemoryView(UInt16[])
440+
(v1, v2) = split_at(mem, 1)
441+
@test v1 == v2
442+
@test isempty(v1)
437443
end
438444

439445
@testset "Split unaligned" begin

0 commit comments

Comments
 (0)