Skip to content

Commit bf49b9d

Browse files
authored
overload NNlib._rng_compat_array (#62)
* overload NNlib._rng_compat_array * trigger CI
1 parent 420e4df commit bf49b9d

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

ext/NNlibCUDA/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "NNlibCUDA"
22
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
3-
version = "0.2.5"
3+
version = "0.2.6"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1313
[compat]
1414
Adapt = "3.3"
1515
CUDA = "3.11"
16-
NNlib = "0.8.14"
16+
NNlib = "0.8.15"
1717
julia = "1.6"
1818

1919
[extras]

ext/NNlibCUDA/src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
NNlib._rng_from_array(::CuArray) = CUDA.default_rng()
22

3+
NNlib._rng_compat_array(rng::CUDA.RNG, A::CuArray) = nothing
4+
NNlib._rng_compat_array(rng::AbstractRNG, A::CuArray) = throw(ArgumentError(
5+
"cannot use rng::$(typeof(rng)) with array::CuArray, only CUDA's own RNG type works"))
6+
37
function divide_kernel!(xs, ys, max_idx)
48
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
59

0 commit comments

Comments
 (0)