|
32 | 32 | # Tracker's implementation of ldiv isn't good. We'll use Zygote's instead. |
33 | 33 | zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B |
34 | 34 |
|
| 35 | +# fixes `randn` on GPU (https://github.com/TuringLang/DistributionsAD.jl/pull/108) |
35 | 36 | function adapt_randn(rng::AbstractRNG, x::AbstractArray, dims...) |
36 | | - adapt(typeof(x), randn(rng, eltype(x), dims...)) |
| 37 | + return adapt_randn(rng, eltype(x), x, dims...) |
37 | 38 | end |
38 | | - |
39 | | -# TODO: should be replaced by @non_differentiable when |
40 | | -# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/212 is fixed |
41 | | -function ChainRules.rrule(::typeof(adapt_randn), rng::AbstractRNG, x::AbstractArray, dims...) |
42 | | - function adapt_randn_pullback(ΔQ) |
43 | | - return (NO_FIELDS, Zero(), Zero(), map(_ -> Zero(), dims)...) |
44 | | - end |
45 | | - adapt_randn(rng, x, dims...), adapt_randn_pullback |
| 39 | +function adapt_randn(rng::AbstractRNG, ::Type{T}, x::AbstractArray, dims...) where {T} |
| 40 | + return adapt(parameterless_type(x), randn(rng, T, dims...)) |
46 | 41 | end |
| 42 | + |
| 43 | +# required by Adapt >= 3.3.0: https://github.com/SciML/OrdinaryDiffEq.jl/issues/1369 |
| 44 | +Base.@pure __parameterless_type(T) = Base.typename(T).wrapper |
| 45 | +parameterless_type(x) = parameterless_type(typeof(x)) |
| 46 | +parameterless_type(x::Type) = __parameterless_type(x) |
| 47 | + |
| 48 | +@non_differentiable adapt_randn(::Any...) |
0 commit comments