11# GPUArrays.jl interface
22
3+ import KernelAbstractions
4+ import KernelAbstractions: Backend
35
46#
57# Device functionality
810
911# # execution
1012
11- struct oneArrayBackend <: AbstractGPUBackend end
12-
13- struct oneKernelContext <: AbstractKernelContext end
13+ struct oneArrayBackend <: Backend end
1414
1515@inline function GPUArrays. launch_heuristic (:: oneArrayBackend , f:: F , args:: Vararg{Any,N} ;
1616 elements:: Int , elements_per_thread:: Int ) where {F,N}
@@ -23,48 +23,6 @@ struct oneKernelContext <: AbstractKernelContext end
2323 return (threads= items, blocks= 32 )
2424end
2525
26- function GPUArrays. gpu_call (:: oneArrayBackend , f, args, threads:: Int , blocks:: Int ;
27- name:: Union{String,Nothing} )
28- @oneapi items= threads groups= blocks name= name f (oneKernelContext (), args... )
29- end
30-
31-
32- # # on-device
33-
34- # indexing
35-
36- GPUArrays. blockidx (ctx:: oneKernelContext ) = oneAPI. get_group_id (0 )
37- GPUArrays. blockdim (ctx:: oneKernelContext ) = oneAPI. get_local_size (0 )
38- GPUArrays. threadidx (ctx:: oneKernelContext ) = oneAPI. get_local_id (0 )
39- GPUArrays. griddim (ctx:: oneKernelContext ) = oneAPI. get_num_groups (0 )
40-
41- # math
42-
43- @inline GPUArrays. cos (ctx:: oneKernelContext , x) = oneAPI. cos (x)
44- @inline GPUArrays. sin (ctx:: oneKernelContext , x) = oneAPI. sin (x)
45- @inline GPUArrays. sqrt (ctx:: oneKernelContext , x) = oneAPI. sqrt (x)
46- @inline GPUArrays. log (ctx:: oneKernelContext , x) = oneAPI. log (x)
47-
48- # memory
49-
50- @inline function GPUArrays. LocalMemory (:: oneKernelContext , :: Type{T} , :: Val{dims} , :: Val{id}
51- ) where {T, dims, id}
52- ptr = oneAPI. emit_localmemory (Val (id), T, Val (prod (dims)))
53- oneDeviceArray (dims, LLVMPtr {T, onePI.AS.Local} (ptr))
54- end
55-
56- # synchronization
57-
58- @inline GPUArrays. synchronize_threads (:: oneKernelContext ) = oneAPI. barrier ()
59-
60-
61-
62- #
63- # Host abstractions
64- #
65-
66- GPUArrays. backend (:: Type{<:oneArray} ) = oneArrayBackend ()
67-
6826const GLOBAL_RNGs = Dict {ZeDevice,GPUArrays.RNG} ()
6927function GPUArrays. default_rng (:: Type{<:oneArray} )
7028 dev = device ()
0 commit comments