Skip to content

Commit 9e3e662

Browse files
committed
basic adaptors
1 parent 30bf26d commit 9e3e662

File tree

3 files changed

+12
-0
lines changed

3 files changed

+12
-0
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
julia 0.6
22
StaticArrays
3+
NNlib

src/GPUArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ include("convolution.jl")
1818
include("testsuite/testsuite.jl")
1919
include("jlbackend.jl")
2020
include("random.jl")
21+
include("nnlib.jl")
2122

2223
export GPUArray, gpu_call, thread_blocks_heuristic, global_size, synchronize_threads
2324
export linear_index, @linearidx, @cartesianidx, convolution!, device, synchronize

src/nnlib.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import NNlib: adapt, adapt_
2+
3+
adapt_(::Type{<:GPUArray}, xs::AbstractArray) =
4+
isbits(xs) ? xs : convert(GPUArray, xs)
5+
6+
adapt_(::Type{<:GPUArray{T}}, xs::AbstractArray{<:Real}) where T <: AbstractFloat =
7+
isbits(xs) ? xs : convert(GPUArray{T}, xs)
8+
9+
# Should go in CLArrays
10+
# cl(xs) = adapt(CLArray{Float32}, xs)

0 commit comments

Comments
 (0)