
Commit cde392d

vchuravy authored and maleadt committed
add backend
1 parent f546631 commit cde392d

3 files changed: 11 additions, 6 deletions

src/abstract_gpu_interface.jl

Lines changed: 5 additions & 5 deletions
@@ -106,8 +106,8 @@ end
 #     CUDAnative.__syncthreads()
 # end
 
-
-
+abstract type GPUBackend end
+backend(::Type{T}) where T = error("Can't choose GPU backend for $T")
 
 """
     gpu_call(kernel::Function, A::GPUArray, args::Tuple, configuration = length(A))
@@ -124,7 +124,7 @@ Optionally, a launch configuration can be supplied in the following way:
 2) Pass a tuple of integer tuples to define blocks and threads per blocks!
 
 """
-function gpu_call(kernel, A::GPUArray, args::Tuple, configuration = length(A))
+function gpu_call(kernel, A::AbstractArray, args::Tuple, configuration = length(A))
     ITuple = NTuple{N, Integer} where N
     # If is a single integer, we assume it to be the global size / total number of threads one wants to launch
     thread_blocks = if isa(configuration, Integer)
@@ -148,8 +148,8 @@ function gpu_call(kernel, A::GPUArray, args::Tuple, configuration = length(A))
         `linear_index` will be inbetween 1:prod((blocks..., threads...))
         """)
     end
-    _gpu_call(kernel, A, args, thread_blocks)
+    _gpu_call(backend(typeof(A)), kernel, A, args, thread_blocks)
 end
 
 # Internal GPU call function, that needs to be overloaded by the backends.
-_gpu_call(f, A, args, thread_blocks) = error("Not implemented")
+_gpu_call(::Any, f, A, args, thread_blocks) = error("Not implemented")
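
This hunk turns the launch path into a backend trait: `gpu_call` now accepts any `AbstractArray`, looks up a backend singleton via `backend(typeof(A))`, and dispatches the internal `_gpu_call` on it. A minimal sketch of what a backend now has to provide follows; all names are made up for illustration and are not part of this commit, and the serial loop merely stands in for a real device launch.

import GPUArrays: GPUBackend, backend, _gpu_call  # assuming the module is named GPUArrays

# Backend singleton and a placeholder array type that claims it.
struct MyBackend <: GPUBackend end

struct MyArray{T, N} <: AbstractArray{T, N}
    data::Array{T, N}
end
Base.size(A::MyArray) = size(A.data)
Base.getindex(A::MyArray, i::Int) = A.data[i]

# Tie the array type to the backend so backend(typeof(A)) resolves it.
backend(::Type{<:MyArray}) = MyBackend()

# Overload only the internal entry point; gpu_call reaches it through
# _gpu_call(backend(typeof(A)), kernel, A, args, thread_blocks).
function _gpu_call(::MyBackend, kernel, A, args, thread_blocks)
    blocks, threads = thread_blocks
    for i in 1:prod((blocks..., threads...))
        kernel(args...)   # stand-in for launching the kernel on a device
    end
end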

src/array.jl

Lines changed: 3 additions & 1 deletion
@@ -21,6 +21,8 @@ function JLArray{T, N}(size::NTuple{N, Integer}) where {T, N}
     JLArray{T, N}(Array{T, N}(undef, size), size)
 end
 
+struct JLBackend <: GPUBackend end
+backend(::Type{<:JLArray}) = JLBackend()
 
 ## getters
 
@@ -120,7 +122,7 @@ function AbstractDeviceArray(ptr::Array, shape::Vararg{Integer, N}) where N
     reshape(ptr, shape)
 end
 
-function _gpu_call(f, A::JLArray, args::Tuple, blocks_threads::Tuple{T, T}) where T <: NTuple{N, Integer} where N
+function _gpu_call(::JLBackend, f, A, args::Tuple, blocks_threads::Tuple{T, T}) where T <: NTuple{N, Integer} where N
     blocks, threads = blocks_threads
     idx = ntuple(i-> 1, length(blocks))
     blockdim = blocks

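With `JLBackend` registered for `JLArray`, the generic `gpu_call` path from the previous file lands in the `_gpu_call(::JLBackend, ...)` method above. A rough usage sketch follows; the kernel convention of a launch `state` argument read by `linear_index` is assumed from the error message in src/abstract_gpu_interface.jl rather than shown in this diff.

using GPUArrays  # assuming JLArray, gpu_call and linear_index are available here

# Assumed convention: the backend passes a launch state as the kernel's first
# argument, and linear_index(state) yields the current work-item index.
function fill_index!(state, y)
    i = linear_index(state)
    i > length(y) && return        # guard in case the launch is padded
    @inbounds y[i] = i
    return
end

y = JLArray(zeros(Float32, 16))
gpu_call(fill_index!, y, (y,))     # backend(typeof(y)) == JLBackend(), so the
                                   # call runs on the CPU reference backend
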
src/broadcast.jl

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,9 @@ BroadcastStyle(::Type{T}) where {T<:GPUArray} = ArrayStyle{T}()
 BroadcastStyle(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
 BroadcastStyle(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = BroadcastStyle(T)
 
+backend(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where {T<:GPUArray} = backend(T)
+backend(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where {T<:GPUArray} = backend(T)
+
 # This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
 # and we could define our methods in terms of Union{GPUArray, WrappedArray{<:Any, <:GPUArray}}
 const GPUDestArray = Union{GPUArray,

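These two methods forward the backend trait through the LinearAlgebra wrappers, mirroring the `BroadcastStyle` forwarding directly above; without them, launching a kernel on a transposed or adjointed GPU array would fall through to the generic `backend(::Type{T})` error added in src/abstract_gpu_interface.jl. A small illustration, assuming the `JLArray`/`JLBackend` definitions from src/array.jl are in scope:

using LinearAlgebra, GPUArrays

x = JLArray(rand(Float32, 4, 4))
backend(typeof(x))             # JLBackend()
backend(typeof(x'))            # still JLBackend(), via the new Adjoint method
backend(typeof(transpose(x)))  # likewise, via the new Transpose method
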
0 commit comments
