diff --git a/src/PtrArrays.jl b/src/PtrArrays.jl index 4a3003f..6d8660e 100644 --- a/src/PtrArrays.jl +++ b/src/PtrArrays.jl @@ -24,28 +24,23 @@ struct PtrArray{T, N} <: DenseArray{T, N} ptr::Ptr{T} size::NTuple{N, Int} function PtrArray(ptr::Ptr{T}, dims::Vararg{Int, N}; check_dims=true) where {T, N} - check_dims && checked_dims(sizeof(T), dims...; message=:PtrArray) + check_dims && checked_dims(dims..., sizeof(T); message=:PtrArray) new{T, N}(ptr, dims) end end # Because Core.checked_dims is buggy 😢 -checked_dims(elsize::Int; message) = elsize -function checked_dims(elsize::Int, d0::Int, d::Int...; message) - overflow = false - neg = (d0+1) < 1 - zero = false # of d0==0 we won't have overflow since we go left to right - len = d0 - for di in d - len, o = Base.mul_with_overflow(len, di) - zero |= di === 0 - overflow |= o - neg |= (di+1) < 1 +function checked_dims(d0::Int, ds::Vararg{Int, N}; message) where N + @static VERSION >= v"1.10" && Base.@assume_effects :terminates_locally + for d in ds + d0+1 < 1 && throw(ArgumentError("invalid $message dimensions")) + d0, o = Base.mul_with_overflow(d0, d) + if o + !isempty(ds) && any(iszero, Base.front(ds)) && return 0 + throw(ArgumentError("invalid $message dimensions")) + end end - len, o = Base.mul_with_overflow(len, elsize) - err = o | neg | overflow & !zero - err && throw(ArgumentError("invalid $message dimensions")) - len + d0 end """ @@ -60,7 +55,7 @@ to call [`free`](@ref) on it when it is no longer needed. """ function malloc(::Type{T}, dims::Int...) where T isbitstype(T) || throw(ArgumentError("malloc only supports isbits types")) - ptr = Libc.malloc(checked_dims(sizeof(T), dims...; message=:malloc)) + ptr = Libc.malloc(checked_dims(dims..., sizeof(T); message=:malloc)) ptr === C_NULL && throw(OutOfMemoryError()) PtrArray(Ptr{T}(ptr), dims..., check_dims=false) end diff --git a/test/runtests.jl b/test/runtests.jl index 65ea77a..8be56bf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -68,9 +68,27 @@ function f(x, y) free(z) res end +function g(::Val{N}) where N + z = malloc(Int, ntuple(i -> 2, Val(N))...) + free(z) +end @testset "Allocations" begin @test f(10, 1:10) == 55 @test 0 == @allocated f(10, 1:10) + @test g(Val(0)) === nothing + @test g(Val(1)) === nothing + @test g(Val(2)) === nothing + @test g(Val(3)) === nothing + @test g(Val(4)) === nothing + @test g(Val(10)) === nothing + @test g(Val(15)) === nothing + @test 0 == @allocated g(Val(0)) + @test 0 == @allocated g(Val(1)) + VERSION >= v"1.11" && @test 0 == @allocated g(Val(2)) + VERSION >= v"1.11" && @test 0 == @allocated g(Val(3)) + VERSION >= v"1.11" && @test 0 == @allocated g(Val(4)) + VERSION >= v"1.11" && @test 0 == @allocated g(Val(10)) + VERSION >= v"1.11" && @test 0 == @allocated g(Val(15)) end @testset "Invalid dimensions" begin