Skip to content

Commit 285ba92

Browse files
authored
Merge pull request #368 from JuliaGPU/tb/convert
Simplify conversions.
2 parents bb9ca6d + 7c94697 commit 285ba92

File tree

1 file changed

+9
-58
lines changed

1 file changed

+9
-58
lines changed

src/host/construction.jl

Lines changed: 9 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
# convenience and indirect construction
22

3+
# conversions from CPU arrays rely on constructors
4+
Base.convert(::Type{T}, a::AbstractArray) where {T<:AbstractGPUArray} = a isa T ? a : T(a)
5+
# TODO: can we implement constructors to and from ::AbstractArray here? by calling the undef
6+
# constructor and doing a `copyto!`. this is tricky, due to ambiguities, and no easy
7+
# way to go from <:AbstractGPUArray{T,N} to e.g. CuArray{S,N}
8+
9+
10+
## convenience constructors
11+
312
function Base.fill!(A::AnyGPUArray{T}, x) where T
413
length(A) == 0 && return A
514
gpu_call(A, convert(T, x)) do ctx, a, val
@@ -49,61 +58,3 @@ end
4958

5059
Base.one(x::AbstractGPUMatrix{T}) where {T} = _one(one(T), x)
5160
Base.oneunit(x::AbstractGPUMatrix{T}) where {T} = _one(oneunit(T), x)
52-
53-
54-
## collect & convert
55-
56-
function indexstyle(x::T) where T
57-
style = try
58-
Base.IndexStyle(x)
59-
catch
60-
nothing
61-
end
62-
style
63-
end
64-
65-
function collect_kernel(ctx::AbstractKernelContext, A, iter, ::IndexCartesian)
66-
idx = @cartesianidx(A)
67-
@inbounds A[idx...] = iter[idx...]
68-
return
69-
end
70-
71-
function collect_kernel(ctx::AbstractKernelContext, A, iter, ::IndexLinear)
72-
idx = linear_index(ctx)
73-
@inbounds A[idx] = iter[idx]
74-
return
75-
end
76-
77-
eltype_or(::Type{<: AbstractGPUArray}, or) = or
78-
eltype_or(::Type{<: AbstractGPUArray{T}}, or) where T = T
79-
eltype_or(::Type{<: AbstractGPUArray{T, N}}, or) where {T, N} = T
80-
81-
function Base.convert(AT::Type{<: AbstractGPUArray}, iter)
82-
isize = Base.IteratorSize(iter)
83-
style = indexstyle(iter)
84-
ettrait = Base.IteratorEltype(iter)
85-
if isbits(iter) && isa(isize, Base.HasShape) && style != nothing && isa(ettrait, Base.HasEltype)
86-
# We can collect on the GPU
87-
A = similar(AT, eltype_or(AT, eltype(iter)), size(iter))
88-
gpu_call(collect_kernel, A, iter, style)
89-
A
90-
else
91-
convert(AT, collect(iter))
92-
end
93-
end
94-
95-
function Base.convert(AT::Type{<: AbstractGPUArray{T, N}}, A::DenseArray{T, N}) where {T, N}
96-
copyto!(AT(undef, size(A)), A)
97-
end
98-
99-
function Base.convert(AT::Type{<: AbstractGPUArray{T1}}, A::DenseArray{T2, N}) where {T1, T2, N}
100-
copyto!(similar(AT, size(A)), convert(Array{T1, N}, A))
101-
end
102-
103-
function Base.convert(AT::Type{<: AbstractGPUArray}, A::DenseArray{T2, N}) where {T2, N}
104-
copyto!(similar(AT{T2}, size(A)), A)
105-
end
106-
107-
function Base.convert(AT::Type{Array{T, N}}, A::AbstractGPUArray{CT, CN}) where {T, N, CT, CN}
108-
convert(AT, copyto!(Array{CT, CN}(undef, size(A)), A))
109-
end

0 commit comments

Comments
 (0)