@@ -15,6 +15,10 @@ using Adapt
1515import KernelAbstractions
1616import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
1717
18+ @static if isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.10
19+ import KernelAbstractions: POCL
20+ end
21+
1822
1923#
2024# Device functionality
@@ -40,30 +44,30 @@ Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[])
4044# # executed on-device
4145
4246# array type
47+ @static if ! isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.9
48+ struct JLDeviceArray{T, N} <: AbstractDeviceArray{T, N}
49+ data:: Vector{UInt8}
50+ offset:: Int
51+ dims:: Dims{N}
52+ end
4353
44- struct JLDeviceArray{T, N} <: AbstractDeviceArray{T, N}
45- data:: Vector{UInt8}
46- offset:: Int
47- dims:: Dims{N}
48- end
54+ Base. elsize (:: Type{<:JLDeviceArray{T}} ) where {T} = sizeof (T)
4955
50- Base. elsize (:: Type{<:JLDeviceArray{T}} ) where {T} = sizeof (T)
56+ Base. size (x:: JLDeviceArray ) = x. dims
57+ Base. sizeof (x:: JLDeviceArray ) = Base. elsize (x) * length (x)
5158
52- Base. size ( x:: JLDeviceArray ) = x . dims
53- Base . sizeof (x :: JLDeviceArray ) = Base. elsize (x) * length (x)
59+ Base. unsafe_convert ( :: Type{Ptr{T}} , x:: JLDeviceArray{T} ) where {T} =
60+ convert (Ptr{T}, pointer (x . data)) + x . offset * Base. elsize (x)
5461
55- Base. unsafe_convert (:: Type{Ptr{T}} , x:: JLDeviceArray{T} ) where {T} =
56- convert (Ptr{T}, pointer (x. data)) + x. offset* Base. elsize (x)
62+ # conversion of untyped data to a typed Array
63+ function typed_data (x:: JLDeviceArray{T} ) where {T}
64+ unsafe_wrap (Array, pointer (x), x. dims)
65+ end
5766
58- # conversion of untyped data to a typed Array
59- function typed_data (x:: JLDeviceArray{T} ) where {T}
60- unsafe_wrap (Array, pointer (x), x. dims)
67+ @inline Base. getindex (A:: JLDeviceArray , index:: Integer ) = getindex (typed_data (A), index)
68+ @inline Base. setindex! (A:: JLDeviceArray , x, index:: Integer ) = setindex! (typed_data (A), x, index)
6169end
6270
63- @inline Base. getindex (A:: JLDeviceArray , index:: Integer ) = getindex (typed_data (A), index)
64- @inline Base. setindex! (A:: JLDeviceArray , x, index:: Integer ) = setindex! (typed_data (A), x, index)
65-
66-
6771#
6872# Host abstractions
6973#
@@ -236,7 +240,7 @@ Base.convert(::Type{T}, x::T) where T <: JLArray = x
236240
237241# # broadcast
238242
239- using Base. Broadcast: BroadcastStyle, Broadcasted
243+ import Base. Broadcast: BroadcastStyle, Broadcasted
240244
241245struct JLArrayStyle{N} <: AbstractGPUArrayStyle{N} end
242246JLArrayStyle {M} (:: Val{N} ) where {N,M} = JLArrayStyle {N} ()
335339
336340# # GPUArrays interfaces
337341
338- Adapt. adapt_storage (:: Adaptor , x:: JLArray{T,N} ) where {T,N} =
339- JLDeviceArray {T,N} (x. data[], x. offset, x. dims)
342+ @static if ! isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.9
343+ Adapt. adapt_storage (:: Adaptor , x:: JLArray{T,N} ) where {T,N} =
344+ JLDeviceArray {T,N} (x. data[], x. offset, x. dims)
345+ else
346+ function Adapt. adapt_storage (:: Adaptor , x:: JLArray{T,N} ) where {T,N}
347+ arr = typed_data (x)
348+ Adapt. adapt_storage (POCL. KernelAdaptor ([pointer (arr)]), arr)
349+ end
350+ end
340351
341352function GPUArrays. mapreducedim! (f, op, R:: AnyJLArray , A:: Union{AbstractArray,Broadcast.Broadcasted} ;
342353 init= nothing )
@@ -377,10 +388,18 @@ KernelAbstractions.allocate(::JLBackend, ::Type{T}, dims::Tuple) where T = JLArr
377388 return ndrange, workgroupsize, iterspace, dynamic
378389end
379390
380- KernelAbstractions. isgpu (b:: JLBackend ) = false
391+ @static if isdefined (JLArrays. KernelAbstractions, :isgpu ) # KA v0.9
392+ KernelAbstractions. isgpu (b:: JLBackend ) = false
393+ end
381394
382- function convert_to_cpu (obj:: Kernel{JLBackend, W, N, F} ) where {W, N, F}
383- return Kernel {typeof(KernelAbstractions.CPU(; static = obj.backend.static)), W, N, F} (KernelAbstractions. CPU (; static = obj. backend. static), obj. f)
395+ @static if ! isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.9
396+ function convert_to_cpu (obj:: Kernel{JLBackend, W, N, F} ) where {W, N, F}
397+ return Kernel {typeof(KernelAbstractions.CPU(; static = obj.backend.static)), W, N, F} (KernelAbstractions. CPU (; static = obj. backend. static), obj. f)
398+ end
399+ else
400+ function convert_to_cpu (obj:: Kernel{JLBackend, W, N, F} ) where {W, N, F}
401+ return Kernel {typeof(KernelAbstractions.POCLBackend()), W, N, F} (KernelAbstractions. POCLBackend (), obj. f)
402+ end
384403end
385404
386405function (obj:: Kernel{JLBackend} )(args... ; ndrange= nothing , workgroupsize= nothing )
391410
392411Adapt. adapt_storage (:: JLBackend , a:: Array ) = Adapt. adapt (JLArrays. JLArray, a)
393412Adapt. adapt_storage (:: JLBackend , a:: JLArrays.JLArray ) = a
394- Adapt. adapt_storage (:: KernelAbstractions.CPU , a:: JLArrays.JLArray ) = convert (Array, a)
413+
414+ @static if ! isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.9
415+ Adapt. adapt_storage (:: KernelAbstractions.CPU , a:: JLArrays.JLArray ) = convert (Array, a)
416+ else
417+ Adapt. adapt_storage (:: KernelAbstractions.POCLBackend , a:: JLArrays.JLArray ) = convert (Array, a)
418+ end
395419
396420end
0 commit comments