Skip to content

Commit df2a49a

Browse files
moblegiordano
andauthored
Add rules for hypot (#2626)
* Add rules for hypot * Don't bother storing hypot value in tape * Format with runic * Deal with complex values * Reduce duplication by rearranging if statements * Avoid NaNs at the origin * Change hypot rules to apply with at least 3 arguments --------- Co-authored-by: Mosè Giordano <[email protected]>
1 parent 442db5b commit df2a49a

File tree

2 files changed

+144
-0
lines changed

2 files changed

+144
-0
lines changed

src/internal_rules.jl

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,3 +1647,107 @@ function EnzymeRules.reverse(
16471647
dst::Annotation{<:AbstractArray})
16481648
return (nothing, nothing)
16491649
end
1650+
1651+
1652+
_hypotforward(x::Const) = zero(x.val)
1653+
_hypotforward(x) = real(conj(x.val) * x.dval)
1654+
_hypotforward(x::Const, i) = zero(x.val)
1655+
_hypotforward(x, i) = real(conj(x.val) * x.dval[i])
1656+
1657+
function EnzymeRules.forward(
1658+
config::EnzymeRules.FwdConfig,
1659+
func::Const{typeof(Base.hypot)},
1660+
RT,
1661+
x::Annotation,
1662+
y::Annotation,
1663+
z::Annotation,
1664+
xs::Vararg{Annotation, N}
1665+
) where {N}
1666+
if EnzymeRules.needs_shadow(config)
1667+
h = func.val(x.val, y.val, z.val, map(x -> x.val, xs)...)
1668+
n = iszero(h) ? one(h) : h
1669+
if EnzymeRules.width(config) == 1
1670+
dh = (
1671+
_hypotforward(x) +
1672+
_hypotforward(y) +
1673+
_hypotforward(z) +
1674+
sum(_hypotforward, xs, init = zero(real(x.val)))
1675+
) / n
1676+
if EnzymeRules.needs_primal(config)
1677+
return Duplicated(h, dh)
1678+
else
1679+
return dh
1680+
end
1681+
else
1682+
dh = ntuple(
1683+
i -> (
1684+
_hypotforward(x, i) +
1685+
_hypotforward(y, i) +
1686+
_hypotforward(z, i) +
1687+
sum(x -> _hypotforward(x, i), xs; init = zero(real(x.val)))
1688+
) / n,
1689+
Val(EnzymeRules.width(config)),
1690+
)
1691+
if EnzymeRules.needs_primal(config)
1692+
return BatchDuplicated(h, dh)
1693+
else
1694+
return dh
1695+
end
1696+
end
1697+
elseif EnzymeRules.needs_primal(config)
1698+
return func.val(x.val, y.val, z.val, map(x -> x.val, xs)...)
1699+
else
1700+
return nothing
1701+
end
1702+
end
1703+
1704+
_hypotreverse(x::Const, ::Val{W}, dret::Const, h) where {W} = nothing
1705+
_hypotreverse(x::Const, ::Val{W}, dret, h) where {W} = nothing
1706+
function _hypotreverse(x, w::Val{W}, dret::Const, h) where {W}
1707+
if W == 1
1708+
return zero(x.val)
1709+
else
1710+
return ntuple(Returns(zero(x.val)), w)
1711+
end
1712+
end
1713+
function _hypotreverse(x, w::Val{W}, dret, h) where {W}
1714+
if W == 1
1715+
return x.val * dret.val / h
1716+
else
1717+
return ntuple(i -> x.val * dret.val[i] / h, w)
1718+
end
1719+
end
1720+
1721+
function EnzymeRules.augmented_primal(
1722+
config::EnzymeRules.RevConfig,
1723+
func::Const{typeof(Base.hypot)},
1724+
::Type,
1725+
x::Annotation,
1726+
y::Annotation,
1727+
z::Annotation,
1728+
xs::Vararg{Annotation, N}
1729+
) where {N}
1730+
h = hypot(x.val, y.val, z.val, map(x -> x.val, xs)...)
1731+
primal = needs_primal(config) ? h : nothing
1732+
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
1733+
end
1734+
1735+
function EnzymeRules.reverse(
1736+
config::EnzymeRules.RevConfig,
1737+
func::Const{typeof(Base.hypot)},
1738+
dret,
1739+
tape,
1740+
x::Annotation,
1741+
y::Annotation,
1742+
z::Annotation,
1743+
xs::Vararg{Annotation, N}
1744+
) where {N}
1745+
h = hypot(x.val, y.val, z.val, map(x -> x.val, xs)...)
1746+
n = iszero(h) ? one(h) : h
1747+
w = Val(EnzymeRules.width(config))
1748+
dx = _hypotreverse(x, w, dret, n)
1749+
dy = _hypotreverse(y, w, dret, n)
1750+
dz = _hypotreverse(z, w, dret, n)
1751+
dxs = map(x -> _hypotreverse(x, w, dret, n), xs)
1752+
return (dx, dy, dz, dxs...)
1753+
end

test/rules/internal_rules.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,4 +830,44 @@ end
830830

831831
end
832832

833+
834+
@testset "hypot rules" begin
835+
@testset "forward" begin
836+
@testset for RT in (Const, DuplicatedNoNeed, Duplicated),
837+
Tx in (Const, Duplicated),
838+
Ty in (Const, Duplicated),
839+
Tz in (Const, Duplicated),
840+
Txs in (Const, Duplicated)
841+
842+
x, y, z, xs = 2.0, 3.0, 5.0, 17.0
843+
test_forward(hypot, RT, (x, Tx), (y, Ty))
844+
test_forward(hypot, RT, (x, Tx), (y, Ty), (z, Tz))
845+
test_forward(hypot, RT, (x, Tx), (y, Ty), (z, Tz), (xs, Txs))
846+
847+
x, y, z, xs = 2.0 + 7.0im, 3.0 + 11.0im, 5.0 + 13.0im, 17.0 + 19.0im
848+
test_forward(hypot, RT, (x, Tx), (y, Ty))
849+
test_forward(hypot, RT, (x, Tx), (y, Ty), (z, Tz))
850+
test_forward(hypot, RT, (x, Tx), (y, Ty), (z, Tz), (xs, Txs))
851+
end
852+
end
853+
@testset "reverse" begin
854+
@testset for RT in (Active,),
855+
Tx in (Const, Active),
856+
Ty in (Const, Active),
857+
Tz in (Const, Active),
858+
Txs in (Const, Active)
859+
860+
x, y, z, xs = 2.0, 3.0, 5.0, 17.0
861+
test_reverse(hypot, RT, (x, Tx), (y, Ty))
862+
test_reverse(hypot, RT, (x, Tx), (y, Ty), (z, Tz))
863+
test_reverse(hypot, RT, (x, Tx), (y, Ty), (z, Tz), (xs, Txs))
864+
865+
x, y, z, xs = 2.0 + 7.0im, 3.0 + 11.0im, 5.0 + 13.0im, 17.0 + 19.0im
866+
test_reverse(hypot, RT, (x, Tx), (y, Ty))
867+
test_reverse(hypot, RT, (x, Tx), (y, Ty), (z, Tz))
868+
test_reverse(hypot, RT, (x, Tx), (y, Ty), (z, Tz), (xs, Txs))
869+
end
870+
end
871+
end
872+
833873
end # InternalRules

0 commit comments

Comments
 (0)