@@ -1647,3 +1647,107 @@ function EnzymeRules.reverse(
16471647 dst:: Annotation{<:AbstractArray} )
16481648 return (nothing , nothing )
16491649end
1650+
1651+
1652+ _hypotforward (x:: Const ) = zero (x. val)
1653+ _hypotforward (x) = real (conj (x. val) * x. dval)
1654+ _hypotforward (x:: Const , i) = zero (x. val)
1655+ _hypotforward (x, i) = real (conj (x. val) * x. dval[i])
1656+
1657+ function EnzymeRules. forward (
1658+ config:: EnzymeRules.FwdConfig ,
1659+ func:: Const{typeof(Base.hypot)} ,
1660+ RT,
1661+ x:: Annotation ,
1662+ y:: Annotation ,
1663+ z:: Annotation ,
1664+ xs:: Vararg{Annotation, N}
1665+ ) where {N}
1666+ if EnzymeRules. needs_shadow (config)
1667+ h = func. val (x. val, y. val, z. val, map (x -> x. val, xs)... )
1668+ n = iszero (h) ? one (h) : h
1669+ if EnzymeRules. width (config) == 1
1670+ dh = (
1671+ _hypotforward (x) +
1672+ _hypotforward (y) +
1673+ _hypotforward (z) +
1674+ sum (_hypotforward, xs, init = zero (real (x. val)))
1675+ ) / n
1676+ if EnzymeRules. needs_primal (config)
1677+ return Duplicated (h, dh)
1678+ else
1679+ return dh
1680+ end
1681+ else
1682+ dh = ntuple (
1683+ i -> (
1684+ _hypotforward (x, i) +
1685+ _hypotforward (y, i) +
1686+ _hypotforward (z, i) +
1687+ sum (x -> _hypotforward (x, i), xs; init = zero (real (x. val)))
1688+ ) / n,
1689+ Val (EnzymeRules. width (config)),
1690+ )
1691+ if EnzymeRules. needs_primal (config)
1692+ return BatchDuplicated (h, dh)
1693+ else
1694+ return dh
1695+ end
1696+ end
1697+ elseif EnzymeRules. needs_primal (config)
1698+ return func. val (x. val, y. val, z. val, map (x -> x. val, xs)... )
1699+ else
1700+ return nothing
1701+ end
1702+ end
1703+
1704+ _hypotreverse (x:: Const , :: Val{W} , dret:: Const , h) where {W} = nothing
1705+ _hypotreverse (x:: Const , :: Val{W} , dret, h) where {W} = nothing
1706+ function _hypotreverse (x, w:: Val{W} , dret:: Const , h) where {W}
1707+ if W == 1
1708+ return zero (x. val)
1709+ else
1710+ return ntuple (Returns (zero (x. val)), w)
1711+ end
1712+ end
1713+ function _hypotreverse (x, w:: Val{W} , dret, h) where {W}
1714+ if W == 1
1715+ return x. val * dret. val / h
1716+ else
1717+ return ntuple (i -> x. val * dret. val[i] / h, w)
1718+ end
1719+ end
1720+
1721+ function EnzymeRules. augmented_primal (
1722+ config:: EnzymeRules.RevConfig ,
1723+ func:: Const{typeof(Base.hypot)} ,
1724+ :: Type ,
1725+ x:: Annotation ,
1726+ y:: Annotation ,
1727+ z:: Annotation ,
1728+ xs:: Vararg{Annotation, N}
1729+ ) where {N}
1730+ h = hypot (x. val, y. val, z. val, map (x -> x. val, xs)... )
1731+ primal = needs_primal (config) ? h : nothing
1732+ return EnzymeRules. AugmentedReturn (primal, nothing , nothing )
1733+ end
1734+
1735+ function EnzymeRules. reverse (
1736+ config:: EnzymeRules.RevConfig ,
1737+ func:: Const{typeof(Base.hypot)} ,
1738+ dret,
1739+ tape,
1740+ x:: Annotation ,
1741+ y:: Annotation ,
1742+ z:: Annotation ,
1743+ xs:: Vararg{Annotation, N}
1744+ ) where {N}
1745+ h = hypot (x. val, y. val, z. val, map (x -> x. val, xs)... )
1746+ n = iszero (h) ? one (h) : h
1747+ w = Val (EnzymeRules. width (config))
1748+ dx = _hypotreverse (x, w, dret, n)
1749+ dy = _hypotreverse (y, w, dret, n)
1750+ dz = _hypotreverse (z, w, dret, n)
1751+ dxs = map (x -> _hypotreverse (x, w, dret, n), xs)
1752+ return (dx, dy, dz, dxs... )
1753+ end
0 commit comments