Skip to content

Commit 978b1fe

Browse files
authored
Fix errors with Adapt 3.3.0 (#161)
* Fix errors with Adapt 3.3.0 * Use `@non_differentiable` * Bump version * Do not drop ChainRules dependency
1 parent f074b8d commit 978b1fe

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DistributionsAD"
22
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
3-
version = "0.6.21"
3+
version = "0.6.22"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2525
[compat]
2626
Adapt = "2, 3"
2727
ChainRules = "0.7"
28-
ChainRulesCore = "0.9.9"
28+
ChainRulesCore = "0.9.21"
2929
Compat = "3.6"
3030
DiffRules = "0.1, 1.0"
3131
Distributions = "0.23.3, 0.24"

src/common.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,17 @@ end
3232
# Tracker's implementation of ldiv isn't good. We'll use Zygote's instead.
3333
zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B
3434

35+
# fixes `randn` on GPU (https://github.com/TuringLang/DistributionsAD.jl/pull/108)
3536
function adapt_randn(rng::AbstractRNG, x::AbstractArray, dims...)
36-
adapt(typeof(x), randn(rng, eltype(x), dims...))
37+
return adapt_randn(rng, eltype(x), x, dims...)
3738
end
38-
39-
# TODO: should be replaced by @non_differentiable when
40-
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/212 is fixed
41-
function ChainRules.rrule(::typeof(adapt_randn), rng::AbstractRNG, x::AbstractArray, dims...)
42-
function adapt_randn_pullback(ΔQ)
43-
return (NO_FIELDS, Zero(), Zero(), map(_ -> Zero(), dims)...)
44-
end
45-
adapt_randn(rng, x, dims...), adapt_randn_pullback
39+
function adapt_randn(rng::AbstractRNG, ::Type{T}, x::AbstractArray, dims...) where {T}
40+
return adapt(parameterless_type(x), randn(rng, T, dims...))
4641
end
42+
43+
# required by Adapt >= 3.3.0: https://github.com/SciML/OrdinaryDiffEq.jl/issues/1369
44+
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
45+
parameterless_type(x) = parameterless_type(typeof(x))
46+
parameterless_type(x::Type) = __parameterless_type(x)
47+
48+
@non_differentiable adapt_randn(::Any...)

src/forwarddiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
function adapt_randn(rng::AbstractRNG, x::AbstractArray{<:ForwardDiff.Dual}, dims...)
2-
adapt(typeof(x), randn(rng, ForwardDiff.valtype(eltype(x)), dims...))
2+
return adapt_randn(rng, ForwardDiff.valtype(eltype(x)), x, dims...)
33
end
44

55
## Binomial ##

0 commit comments

Comments
 (0)