Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function explain_eltype(@nospecialize(T), depth=0; maxdepth=10)
msg *= explain_eltype(U, depth+1)
end
end
elseif Base.ismutabletype(T)
elseif Base.ismutabletype(T) && Base.datatype_fieldcount(T) != 0
msg = " "^depth * "$T is a mutable type\n"
elseif hasfieldcount(T)
msg = " "^depth * "$T is a struct that's not allocated inline\n"
Expand All @@ -47,9 +47,22 @@ end
# these are stored with a selector at the end (handled by Julia).
# 3. bitstype unions (`Union{Int, Float32}`, etc)
# these are stored contiguously and require a selector array (handled by us)
@inline function check_eltype(name, T)
eltype_is_invalid = !Base.allocatedinline(T) || (hasfieldcount(T) && any(!Base.allocatedinline, fieldtypes(T)))
if eltype_is_invalid
# As well as "mutable singleton" types like `Symbol` that use pointer-identity

function valid_type(@nospecialize(T))
if Base.allocatedinline(T)
if hasfieldcount(T)
return all(valid_type, fieldtypes(T))
end
return true
elseif Base.ismutabletype(T)
return Base.datatype_fieldcount(T) == 0
end
return false
end

@inline function check_eltype(name, T)
if !valid_type(T)
explanation = explain_eltype(T)
error("""
$name only supports element types that are allocated inline.
Expand Down Expand Up @@ -234,7 +247,7 @@ end
function Base.unsafe_wrap(::Type{CuArray{T,N,M}},
ptr::CuPtr{T}, dims::NTuple{N,Int};
own::Bool=false, ctx::CuContext=context()) where {T,N,M}
isbitstype(T) || throw(ArgumentError("Can only unsafe_wrap a pointer to a bits type"))
check_eltype("unsafe_wrap(CuArray, ...)", T)
sz = prod(dims) * aligned_sizeof(T)

# create a memory object
Expand Down
12 changes: 12 additions & 0 deletions test/base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,18 @@ end
cpu_arr = unsafe_wrap(Array, cpu_ptr, 1)
@test cpu_arr == [42]
end

# symbols and tuples thereof
let a = CuArray([:a])
b = unsafe_wrap(CuArray, pointer(a), 1)
@test typeof(b) <: CuArray{Symbol,1}
@test size(b) == (1,)
end
let a = CuArray([(:a,:b)])
b = unsafe_wrap(CuArray, pointer(a), 1)
@test typeof(b) <: CuArray{Tuple{Symbol,Symbol},1}
@test size(b) == (1,)
end
end

@testset "adapt" begin
Expand Down