Skip to content
29 changes: 12 additions & 17 deletions src/PtrArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,23 @@ struct PtrArray{T, N} <: DenseArray{T, N}
ptr::Ptr{T}
size::NTuple{N, Int}
function PtrArray(ptr::Ptr{T}, dims::Vararg{Int, N}; check_dims=true) where {T, N}
check_dims && checked_dims(sizeof(T), dims...; message=:PtrArray)
check_dims && checked_dims(dims..., sizeof(T); message=:PtrArray)
new{T, N}(ptr, dims)
end
end

# Because Core.checked_dims is buggy 😢
checked_dims(elsize::Int; message) = elsize
function checked_dims(elsize::Int, d0::Int, d::Int...; message)
overflow = false
neg = (d0+1) < 1
zero = false # of d0==0 we won't have overflow since we go left to right
len = d0
for di in d
len, o = Base.mul_with_overflow(len, di)
zero |= di === 0
overflow |= o
neg |= (di+1) < 1
function checked_dims(d0::Int, ds::Vararg{Int, N}; message) where N
@static VERSION >= v"1.10" && Base.@assume_effects :terminates_locally
for d in ds
d0+1 < 1 && throw(ArgumentError("invalid $message dimensions"))
d0, o = Base.mul_with_overflow(d0, d)
if o
!isempty(ds) && any(iszero, Base.front(ds)) && return 0
throw(ArgumentError("invalid $message dimensions"))
end
end
len, o = Base.mul_with_overflow(len, elsize)
err = o | neg | overflow & !zero
err && throw(ArgumentError("invalid $message dimensions"))
len
d0
end

"""
Expand All @@ -60,7 +55,7 @@ to call [`free`](@ref) on it when it is no longer needed.
"""
function malloc(::Type{T}, dims::Int...) where T
isbitstype(T) || throw(ArgumentError("malloc only supports isbits types"))
ptr = Libc.malloc(checked_dims(sizeof(T), dims...; message=:malloc))
ptr = Libc.malloc(checked_dims(dims..., sizeof(T); message=:malloc))
ptr === C_NULL && throw(OutOfMemoryError())
PtrArray(Ptr{T}(ptr), dims..., check_dims=false)
end
Expand Down
18 changes: 18 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,27 @@ function f(x, y)
free(z)
res
end
function g(::Val{N}) where N
z = malloc(Int, ntuple(i -> 2, Val(N))...)
free(z)
end
@testset "Allocations" begin
@test f(10, 1:10) == 55
@test 0 == @allocated f(10, 1:10)
@test g(Val(0)) === nothing
@test g(Val(1)) === nothing
@test g(Val(2)) === nothing
@test g(Val(3)) === nothing
@test g(Val(4)) === nothing
@test g(Val(10)) === nothing
@test g(Val(15)) === nothing
@test 0 == @allocated g(Val(0))
@test 0 == @allocated g(Val(1))
VERSION >= v"1.11" && @test 0 == @allocated g(Val(2))
VERSION >= v"1.11" && @test 0 == @allocated g(Val(3))
VERSION >= v"1.11" && @test 0 == @allocated g(Val(4))
VERSION >= v"1.11" && @test 0 == @allocated g(Val(10))
VERSION >= v"1.11" && @test 0 == @allocated g(Val(15))
end

@testset "Invalid dimensions" begin
Expand Down
Loading