11# GPUArrays.jl interface
22
3+ import KernelAbstractions
4+ import KernelAbstractions: Backend
35
46#
57# Device functionality
810
911# # execution
1012
11- struct oneArrayBackend <: AbstractGPUBackend end
12-
13- struct oneKernelContext <: AbstractKernelContext end
13+ struct oneArrayBackend <: Backend end
1414
1515@inline function GPUArrays. launch_heuristic(:: oneArrayBackend , f:: F , args:: Vararg{Any,N} ;
1616 elements:: Int , elements_per_thread:: Int ) where {F,N}
@@ -23,48 +23,6 @@ struct oneKernelContext <: AbstractKernelContext end
2323 return (threads= items, blocks= 32 )
2424end
2525
26- function GPUArrays. gpu_call(:: oneArrayBackend , f, args, threads:: Int , blocks:: Int ;
27- name:: Union{String,Nothing} )
28- @oneapi items= threads groups= blocks name= name f(oneKernelContext(), args... )
29- end
30-
31-
32- # # on-device
33-
34- # indexing
35-
36- GPUArrays. blockidx(ctx:: oneKernelContext ) = oneAPI. get_group_id(0 )
37- GPUArrays. blockdim(ctx:: oneKernelContext ) = oneAPI. get_local_size(0 )
38- GPUArrays. threadidx(ctx:: oneKernelContext ) = oneAPI. get_local_id(0 )
39- GPUArrays. griddim(ctx:: oneKernelContext ) = oneAPI. get_num_groups(0 )
40-
41- # math
42-
43- @inline GPUArrays. cos(ctx:: oneKernelContext , x) = oneAPI. cos(x)
44- @inline GPUArrays. sin(ctx:: oneKernelContext , x) = oneAPI. sin(x)
45- @inline GPUArrays. sqrt(ctx:: oneKernelContext , x) = oneAPI. sqrt(x)
46- @inline GPUArrays. log(ctx:: oneKernelContext , x) = oneAPI. log(x)
47-
48- # memory
49-
50- @inline function GPUArrays. LocalMemory(:: oneKernelContext , :: Type{T} , :: Val{dims} , :: Val{id}
51- ) where {T, dims, id}
52- ptr = oneAPI. emit_localmemory(Val(id), T, Val(prod(dims)))
53- oneDeviceArray(dims, LLVMPtr{T, onePI. AS. Local}(ptr))
54- end
55-
56- # synchronization
57-
58- @inline GPUArrays. synchronize_threads(:: oneKernelContext ) = oneAPI. barrier()
59-
60-
61-
62- #
63- # Host abstractions
64- #
65-
66- GPUArrays. backend(:: Type{<:oneArray} ) = oneArrayBackend()
67-
6826const GLOBAL_RNGs = Dict{ZeDevice,GPUArrays. RNG}()
6927function GPUArrays. default_rng(:: Type{<:oneArray} )
7028 dev = device()
0 commit comments