diff --git a/src/array.jl b/src/array.jl index dbf3949b64..d66396bfce 100644 --- a/src/array.jl +++ b/src/array.jl @@ -24,7 +24,7 @@ function explain_eltype(@nospecialize(T), depth=0; maxdepth=10) msg *= explain_eltype(U, depth+1) end end - elseif Base.ismutabletype(T) + elseif Base.ismutabletype(T) && Base.datatype_fieldcount(T) != 0 msg = " "^depth * "$T is a mutable type\n" elseif hasfieldcount(T) msg = " "^depth * "$T is a struct that's not allocated inline\n" @@ -47,9 +47,22 @@ end # these are stored with a selector at the end (handled by Julia). # 3. bitstype unions (`Union{Int, Float32}`, etc) # these are stored contiguously and require a selector array (handled by us) -@inline function check_eltype(name, T) - eltype_is_invalid = !Base.allocatedinline(T) || (hasfieldcount(T) && any(!Base.allocatedinline, fieldtypes(T))) - if eltype_is_invalid +# As well as "mutable singleton" types like `Symbol` that use pointer-identity + +function valid_type(@nospecialize(T)) + if Base.allocatedinline(T) + if hasfieldcount(T) + return all(valid_type, fieldtypes(T)) + end + return true + elseif Base.ismutabletype(T) + return Base.datatype_fieldcount(T) == 0 + end + return false +end + +@inline function check_eltype(name, T) + if !valid_type(T) explanation = explain_eltype(T) error(""" $name only supports element types that are allocated inline. @@ -234,7 +247,7 @@ end function Base.unsafe_wrap(::Type{CuArray{T,N,M}}, ptr::CuPtr{T}, dims::NTuple{N,Int}; own::Bool=false, ctx::CuContext=context()) where {T,N,M} - isbitstype(T) || throw(ArgumentError("Can only unsafe_wrap a pointer to a bits type")) + check_eltype("unsafe_wrap(CuArray, ...)", T) sz = prod(dims) * aligned_sizeof(T) # create a memory object diff --git a/test/base/array.jl b/test/base/array.jl index 1be82df428..42f4668aff 100644 --- a/test/base/array.jl +++ b/test/base/array.jl @@ -173,6 +173,18 @@ end cpu_arr = unsafe_wrap(Array, cpu_ptr, 1) @test cpu_arr == [42] end + + # symbols and tuples thereof + let a = CuArray([:a]) + b = unsafe_wrap(CuArray, pointer(a), 1) + @test typeof(b) <: CuArray{Symbol,1} + @test size(b) == (1,) + end + let a = CuArray([(:a,:b)]) + b = unsafe_wrap(CuArray, pointer(a), 1) + @test typeof(b) <: CuArray{Tuple{Symbol,Symbol},1} + @test size(b) == (1,) + end end @testset "adapt" begin