Skip to content

Commit f35b436

Browse files
committed
Support arrays of symbols and tuple of symbols
1 parent 71da935 commit f35b436

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

src/array.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function explain_eltype(@nospecialize(T), depth=0; maxdepth=10)
2424
msg *= explain_eltype(U, depth+1)
2525
end
2626
end
27-
elseif Base.ismutabletype(T)
27+
elseif Base.ismutabletype(T) && Base.datatype_fieldcount(T) != 0
2828
msg = " "^depth * "$T is a mutable type\n"
2929
elseif hasfieldcount(T)
3030
msg = " "^depth * "$T is a struct that's not allocated inline\n"
@@ -47,8 +47,11 @@ end
4747
# these are stored with a selector at the end (handled by Julia).
4848
# 3. bitstype unions (`Union{Int, Float32}`, etc)
4949
# these are stored contiguously and require a selector array (handled by us)
50+
# As well as "mutable singleton" types like `Symbol` that use pointer-identity
5051
@inline function check_eltype(name, T)
51-
eltype_is_invalid = !Base.allocatedinline(T) || (hasfieldcount(T) && any(!Base.allocatedinline, fieldtypes(T)))
52+
eltype_is_invalid = !Base.allocatedinline(T) ||
53+
(hasfieldcount(T) && any(!Base.allocatedinline, fieldtypes(T)))
54+
5255
if eltype_is_invalid
5356
explanation = explain_eltype(T)
5457
error("""
@@ -234,7 +237,7 @@ end
234237
function Base.unsafe_wrap(::Type{CuArray{T,N,M}},
235238
ptr::CuPtr{T}, dims::NTuple{N,Int};
236239
own::Bool=false, ctx::CuContext=context()) where {T,N,M}
237-
isbitstype(T) || throw(ArgumentError("Can only unsafe_wrap a pointer to a bits type"))
240+
check_eltype("unsafe_wrap(CuArray, ...)", T)
238241
sz = prod(dims) * aligned_sizeof(T)
239242

240243
# create a memory object

test/base/array.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,18 @@ end
173173
cpu_arr = unsafe_wrap(Array, cpu_ptr, 1)
174174
@test cpu_arr == [42]
175175
end
176+
177+
# symbols and tuples thereof
178+
let a = CuArray([:a])
179+
b = unsafe_wrap(CuArray, pointer(a), 1)
180+
@test typeof(b) <: CuArray{Symbol,1}
181+
@test size(b) == (1,)
182+
end
183+
let a = CuArray([(:a,:b)])
184+
b = unsafe_wrap(CuArray, pointer(a), 1)
185+
@test typeof(b) <: CuArray{Tuple{Symbol,Symbol},1}
186+
@test size(b) == (1,)
187+
end
176188
end
177189

178190
@testset "adapt" begin

0 commit comments

Comments
 (0)