Skip to content

Commit 5a4fedd

Browse files
committed
Try to fix on older julia
1 parent 81984ed commit 5a4fedd

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

ext/SpecialFunctionsChainRulesCoreExt.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δx), ::typeof(beta_inc), a::Number,
636636
# derivatives
637637
T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)))
638638
_, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x))
639-
Δp = dIa * T(Δa) + dIb * T(Δb) + dIx * T(Δx)
639+
Δp = dIa * convert(T, Δa) + dIb * convert(T, Δb) + dIx * convert(T, Δx)
640640
Δq = -Δp
641641
Tout = typeof((p, q))
642642
return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq)
@@ -651,7 +651,7 @@ function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Numbe
651651
_, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x))
652652
function beta_inc_pullback(Δ)
653653
Δp, Δq = Δ
654-
s = T(Δp) - T(Δq) # because q = 1 - p
654+
s = Δp - Δq # because q = 1 - p
655655
= Ta(s * dIa)
656656
= Tb(s * dIb)
657657
= Tx(s * dIx)
@@ -663,7 +663,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δx, Δy), ::typeof(beta_inc), a::Nu
663663
p, q = beta_inc(a, b, x, y)
664664
T = promote_type(float(typeof(a)), float(typeof(b)), float(typeof(x)), float(typeof(y)))
665665
_, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x))
666-
Δp = dIa * T(Δa) + dIb * T(Δb) + dIx * (T(Δx) - T(Δy))
666+
Δp = dIa * convert(T, Δa) + dIb * convert(T, Δb) + dIx * (convert(T, Δx) - convert(T, Δy))
667667
Δq = -Δp
668668
Tout = typeof((p, q))
669669
return (p, q), ChainRulesCore.Tangent{Tout}(Δp, Δq)
@@ -679,7 +679,7 @@ function ChainRulesCore.rrule(::typeof(beta_inc), a::Number, b::Number, x::Numbe
679679
_, dIa, dIb, dIx = _ibeta_grad_splus(T(a), T(b), T(x))
680680
function beta_inc_pullback(Δ)
681681
Δp, Δq = Δ
682-
s = T(Δp) - T(Δq)
682+
s = Δp - Δq
683683
= Ta(s * dIa)
684684
= Tb(s * dIb)
685685
= Tx(s * dIx)
@@ -701,7 +701,7 @@ function ChainRulesCore.frule((_, Δa, Δb, Δp), ::typeof(beta_inc_inv), a::Num
701701
dx_da = -dIa * inv_dIx
702702
dx_db = -dIb * inv_dIx
703703
dx_dp = inv_dIx
704-
Δx = dx_da * T(Δa) + dx_db * T(Δb) + dx_dp * T(Δp)
704+
Δx = dx_da * convert(T, Δa) + dx_db * convert(T, Δb) + dx_dp * convert(T, Δp)
705705
Δy = -Δx
706706
Tout = typeof((x, y))
707707
return (x, y), ChainRulesCore.Tangent{Tout}(Δx, Δy)
@@ -722,7 +722,7 @@ function ChainRulesCore.rrule(::typeof(beta_inc_inv), a::Number, b::Number, p::N
722722
dx_dp = inv_dIx
723723
function beta_inc_inv_pullback(Δ)
724724
Δx, Δy = Δ
725-
s = T(Δx) - T(Δy)
725+
s = Δx - Δy
726726
= Ta(s * dx_da)
727727
= Tb(s * dx_db)
728728
= Tp(s * dx_dp)

0 commit comments

Comments
 (0)