|
1 |
| -ChainRulesCore.@scalar_rule(xlogx(x::Real), (1 + log(x),)) |
2 |
| -ChainRulesCore.@scalar_rule(xlogy(x::Real, y::Real), (log(y), x / y,)) |
3 |
| -ChainRulesCore.@scalar_rule(xlog1py(x::Real, y::Real), (log1p(y), x / (1 + y),)) |
| 1 | +function _Ω_∂_xlogx(x::Real) |
| 2 | + logx = log(x) |
| 3 | + y = x * logx |
| 4 | + Ω = iszero(x) ? zero(y) : y |
| 5 | + ∂x = 1 + logx |
| 6 | + return Ω, ∂x |
| 7 | +end |
| 8 | +function ChainRulesCore.frule((_, Δx), ::typeof(xlogx), x::Real) |
| 9 | + Ω, ∂x = _Ω_∂_xlogx(x) |
| 10 | + ΔΩ = ∂x * Δx |
| 11 | + return Ω, ΔΩ |
| 12 | +end |
| 13 | +function ChainRulesCore.rrule(::typeof(xlogx), x::Real) |
| 14 | + Ω, ∂x = _Ω_∂_xlogx(x) |
| 15 | + xlogx_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ∂x * ΔΩ) |
| 16 | + return Ω, xlogx_pullback |
| 17 | +end |
4 | 18 |
|
5 |
| -function ChainRulesCore.frule((_, Δx), ::typeof(xexpx), x::Real) |
| 19 | +function _Ω_∂_xlogy(x::Real, y::Real) |
| 20 | + logy = log(y) |
| 21 | + z = x * logy |
| 22 | + Ω = iszero(x) && !isnan(y) ? zero(z) : z |
| 23 | + ∂x = logy |
| 24 | + ∂y = x / y |
| 25 | + return Ω, ∂x, ∂y |
| 26 | +end |
| 27 | +function ChainRulesCore.frule((_, Δx, Δy), ::typeof(xlogy), x::Real, y::Real) |
| 28 | + Ω, ∂x, ∂y = _Ω_∂_xlogy(x, y) |
| 29 | + ΔΩ = muladd(∂x, Δx, ∂y * Δy) |
| 30 | + return Ω, ΔΩ |
| 31 | +end |
| 32 | +function ChainRulesCore.rrule(::typeof(xlogy), x::Real, y::Real) |
| 33 | + Ω, ∂x, ∂y = _Ω_∂_xlogy(x, y) |
| 34 | + xlogy_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ∂x * ΔΩ, ∂y * ΔΩ) |
| 35 | + return Ω, xlogy_pullback |
| 36 | +end |
| 37 | + |
| 38 | +function _Ω_∂_xlog1py(x::Real, y::Real) |
| 39 | + log1py = log1p(y) |
| 40 | + z = x * log1py |
| 41 | + Ω = iszero(x) && !isnan(y) ? zero(z) : z |
| 42 | + ∂x = log1py |
| 43 | + ∂y = x / (1 + y) |
| 44 | + return Ω, ∂x, ∂y |
| 45 | +end |
| 46 | +function ChainRulesCore.frule((_, Δx, Δy), ::typeof(xlog1py), x::Real, y::Real) |
| 47 | + Ω, ∂x, ∂y = _Ω_∂_xlog1py(x, y) |
| 48 | + ΔΩ = muladd(∂x, Δx, ∂y * Δy) |
| 49 | + return Ω, ΔΩ |
| 50 | +end |
| 51 | +function ChainRulesCore.rrule(::typeof(xlog1py), x::Real, y::Real) |
| 52 | + Ω, ∂x, ∂y = _Ω_∂_xlog1py(x, y) |
| 53 | + xlog1py_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ∂x * ΔΩ, ∂y * ΔΩ) |
| 54 | + return Ω, xlog1py_pullback |
| 55 | +end |
| 56 | + |
| 57 | +function _Ω_∂_xexpx(x::Real) |
6 | 58 | expx = exp(x)
|
7 | 59 | if iszero(expx)
|
8 | 60 | Ω = expx
|
9 |
| - ΔΩ = expx * Δx |
| 61 | + ∂x = expx |
10 | 62 | else
|
11 | 63 | Ω = x * expx
|
12 |
| - ΔΩ = (1 + x) * expx * Δx |
| 64 | + ∂x = (1 + x) * expx |
13 | 65 | end
|
| 66 | + return Ω, ∂x |
| 67 | +end |
| 68 | +function ChainRulesCore.frule((_, Δx), ::typeof(xexpx), x::Real) |
| 69 | + Ω, ∂x = _Ω_∂_xexpx(x) |
| 70 | + ΔΩ = ∂x * Δx |
14 | 71 | return Ω, ΔΩ
|
15 | 72 | end
|
16 |
| - |
17 | 73 | function ChainRulesCore.rrule(::typeof(xexpx), x::Real)
|
18 |
| - expx = exp(x) |
19 |
| - Ω = iszero(expx) ? expx : x * expx |
20 |
| - function xexpx_pullback(ΔΩ) |
21 |
| - Δx = iszero(expx) ? expx * ΔΩ : (1 + x) * expx * ΔΩ |
22 |
| - return (ChainRulesCore.NoTangent(), Δx) |
23 |
| - end |
| 74 | + Ω, ∂x = _Ω_∂_xexpx(x) |
| 75 | + xexpx_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ∂x * ΔΩ) |
24 | 76 | return Ω, xexpx_pullback
|
25 | 77 | end
|
26 | 78 |
|
27 |
| -function ChainRulesCore.frule((_, Δx, Δy), ::typeof(xexpy), x::Real, y::Real) |
| 79 | +function _Ω_∂_xexpy(x::Real, y::Real) |
28 | 80 | expy = exp(y)
|
29 | 81 | result = x * expy
|
30 | 82 | Ω = iszero(expy) && !isnan(x) ? zero(result) : result
|
31 |
| - ΔΩ = expy * Δx + Ω * Δy |
| 83 | + ∂x = expy |
| 84 | + ∂y = Ω |
| 85 | + return Ω, ∂x, ∂y |
| 86 | +end |
| 87 | +function ChainRulesCore.frule((_, Δx, Δy), ::typeof(xexpy), x::Real, y::Real) |
| 88 | + Ω, ∂x, ∂y = _Ω_∂_xexpy(x, y) |
| 89 | + ΔΩ = muladd(∂x, Δx, ∂y * Δy) |
32 | 90 | return Ω, ΔΩ
|
33 | 91 | end
|
34 |
| - |
35 | 92 | function ChainRulesCore.rrule(::typeof(xexpy), x::Real, y::Real)
|
36 |
| - expy = exp(y) |
37 |
| - result = x * expy |
38 |
| - Ω = iszero(expy) && !isnan(x) ? zero(result) : result |
39 |
| - xexpy_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ΔΩ * expy, ΔΩ * Ω) |
| 93 | + Ω, ∂x, ∂y = _Ω_∂_xexpy(x, y) |
| 94 | + xexpy_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ∂x * ΔΩ, ∂y * ΔΩ) |
40 | 95 | return Ω, xexpy_pullback
|
41 | 96 | end
|
42 | 97 |
|
|
0 commit comments