Skip to content

Commit 9cdea3d

Browse files
committed
unsafe_wrap for symbols
1 parent 59da6c9 commit 9cdea3d

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/array.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,6 @@ end
231231
function Base.unsafe_wrap(::Type{CuArray{T,N,M}},
232232
ptr::CuPtr{T}, dims::NTuple{N,Int};
233233
own::Bool=false, ctx::CuContext=context()) where {T,N,M}
234-
isbitstype(T) || throw(ArgumentError("Can only unsafe_wrap a pointer to a bits type"))
235234
sz = prod(dims) * sizeof(T)
236235

237236
# create a memory object

test/base/array.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,18 @@ end
155155
cpu_arr = unsafe_wrap(Array, cpu_ptr, 1)
156156
@test cpu_arr == [42]
157157
end
158+
159+
# symbols and tuples thereof
160+
let a = CuArray([:a])
161+
b = unsafe_wrap(CuArray, pointer(a), 1)
162+
@test typeof(b) <: CuArray{Symbol,1}
163+
@test size(b) == (1,)
164+
end
165+
let a = CuArray([(:a,:b)])
166+
b = unsafe_wrap(CuArray, pointer(a), 1)
167+
@test typeof(b) <: CuArray{Tuple{Symbol,Symbol},1}
168+
@test size(b) == (1,)
169+
end
158170
end
159171

160172
@testset "adapt" begin

0 commit comments

Comments
 (0)