@@ -13,7 +13,11 @@ using GPUArrays
1313using Adapt
1414
1515import KernelAbstractions
16- import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config, POCL
16+ import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
17+
18+ @static if isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.10
19+ import KernelAbstractions: POCL
20+ end
1721
1822
1923#
@@ -40,8 +44,29 @@ Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[])
4044# # executed on-device
4145
4246# array type
47+ @static if ! isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.9
48+ struct JLDeviceArray{T, N} <: AbstractDeviceArray{T, N}
49+ data:: Vector{UInt8}
50+ offset:: Int
51+ dims:: Dims{N}
52+ end
53+
54+ Base. elsize (:: Type{<:JLDeviceArray{T}} ) where {T} = sizeof (T)
4355
56+ Base. size (x:: JLDeviceArray ) = x. dims
57+ Base. sizeof (x:: JLDeviceArray ) = Base. elsize (x) * length (x)
4458
59+ Base. unsafe_convert (:: Type{Ptr{T}} , x:: JLDeviceArray{T} ) where {T} =
60+ convert (Ptr{T}, pointer (x. data)) + x. offset* Base. elsize (x)
61+
62+ # conversion of untyped data to a typed Array
63+ function typed_data (x:: JLDeviceArray{T} ) where {T}
64+ unsafe_wrap (Array, pointer (x), x. dims)
65+ end
66+
67+ @inline Base. getindex (A:: JLDeviceArray , index:: Integer ) = getindex (typed_data (A), index)
68+ @inline Base. setindex! (A:: JLDeviceArray , x, index:: Integer ) = setindex! (typed_data (A), x, index)
69+ end
4570
4671#
4772# Host abstractions
314339
315340# # GPUArrays interfaces
316341
317- function Adapt. adapt_storage (:: Adaptor , x:: JLArray{T,N} ) where {T,N}
318- arr = typed_data (x)
319- Adapt. adapt_storage (POCL. KernelAdaptor ([pointer (arr)]), arr)
342+ @static if ! isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.9
343+ Adapt. adapt_storage (:: Adaptor , x:: JLArray{T,N} ) where {T,N} =
344+ JLDeviceArray {T,N} (x. data[], x. offset, x. dims)
345+ else
346+ function Adapt. adapt_storage (:: Adaptor , x:: JLArray{T,N} ) where {T,N}
347+ arr = typed_data (x)
348+ Adapt. adapt_storage (POCL. KernelAdaptor ([pointer (arr)]), arr)
349+ end
320350end
321351
322352function GPUArrays. mapreducedim! (f, op, R:: AnyJLArray , A:: Union{AbstractArray,Broadcast.Broadcasted} ;
@@ -358,8 +388,18 @@ KernelAbstractions.allocate(::JLBackend, ::Type{T}, dims::Tuple) where T = JLArr
358388 return ndrange, workgroupsize, iterspace, dynamic
359389end
360390
361- function convert_to_cpu (obj:: Kernel{JLBackend, W, N, F} ) where {W, N, F}
362- return Kernel {typeof(KernelAbstractions.POCLBackend()), W, N, F} (KernelAbstractions. POCLBackend (), obj. f)
391+ @static if ! isdefined (JLArrays. KernelAbstractions, :isgpu ) # KA v0.9
392+ KernelAbstractions. isgpu (b:: JLBackend ) = false
393+ end
394+
395+ @static if ! isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.9
396+ function convert_to_cpu (obj:: Kernel{JLBackend, W, N, F} ) where {W, N, F}
397+ return Kernel {typeof(KernelAbstractions.CPU(; static = obj.backend.static)), W, N, F} (KernelAbstractions. CPU (; static = obj. backend. static), obj. f)
398+ end
399+ else
400+ function convert_to_cpu (obj:: Kernel{JLBackend, W, N, F} ) where {W, N, F}
401+ return Kernel {typeof(KernelAbstractions.POCLBackend()), W, N, F} (KernelAbstractions. POCLBackend (), obj. f)
402+ end
363403end
364404
365405function (obj:: Kernel{JLBackend} )(args... ; ndrange= nothing , workgroupsize= nothing )
370410
371411Adapt. adapt_storage (:: JLBackend , a:: Array ) = Adapt. adapt (JLArrays. JLArray, a)
372412Adapt. adapt_storage (:: JLBackend , a:: JLArrays.JLArray ) = a
373- Adapt. adapt_storage (:: KernelAbstractions.POCLBackend , a:: JLArrays.JLArray ) = convert (Array, a)
413+
414+ @static if ! isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.9
415+ Adapt. adapt_storage (:: KernelAbstractions.CPU , a:: JLArrays.JLArray ) = convert (Array, a)
416+ else
417+ Adapt. adapt_storage (:: KernelAbstractions.POCLBackend , a:: JLArrays.JLArray ) = convert (Array, a)
418+ end
374419
375420end
0 commit comments