Skip to content

Commit f870554

Browse files
Simplify fastpow implementation via T output
1 parent 7fc48c1 commit f870554

File tree

2 files changed

+13
-94
lines changed

2 files changed

+13
-94
lines changed

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -53,46 +53,6 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}
5353
return ntuple(_ -> nothing, Val(length(args) + 4))
5454
end
5555

56-
function Enzyme.EnzymeRules.forward(func::Const{typeof(DiffEqBase.fastpow)},
57-
RT::Type{<:Union{Duplicated, DuplicatedNoNeed}},
58-
_x::Union{Const, Duplicated}, _y::Union{Const, Duplicated})
59-
x = _x.val
60-
y = _y.val
61-
ret = func.val(x, y)
62-
if !(_x isa Const)
63-
dxval = _x.dval * y * (fastpow(x,y - 1))
64-
else
65-
dxval = make_zero(_x.val)
66-
end
67-
if !(_y isa Const)
68-
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : _y.dval*(fastpow(x,y))*log(x)
69-
else
70-
dyval = make_zero(_y.val)
71-
end
72-
if RT <: DuplicatedNoNeed
73-
return Float32(dxval + dyval)
74-
else
75-
return Duplicated(ret, Float32(dxval + dyval))
76-
end
77-
end
78-
79-
function EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWidth{1},
80-
func::Const{typeof(fastpow)}, ::Type{<:Active}, x::Active, y::Active)
81-
if EnzymeRules.needs_primal(config)
82-
primal = func.val(x.val, y.val)
83-
else
84-
primal = nothing
85-
end
86-
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
87-
end
88-
89-
function EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1},
90-
func::Const{typeof(DiffEqBase.fastpow)}, dret::Active, tape, _x::Active, _y::Active)
91-
x = _x.val
92-
y = _y.val
93-
dxval = dret.val * y * (fastpow(x,y - 1))
94-
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : dret.val * (fastpow(x,y))*log(x)
95-
return (dxval, dyval)
96-
end
56+
Enzyme.Compiler.known_ops[typeof(DiffEqBase.fastpow)] = (:pow, 2, nothing)
9757

9858
end

src/fastpow.jl

Lines changed: 12 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -51,60 +51,19 @@ const EXP2FT = (Float32(0x1.6a09e667f3bcdp-1),
5151
Float32(0x1.3dea64c123422p+0),
5252
Float32(0x1.4bfdad5362a27p+0),
5353
Float32(0x1.5ab07dd485429p+0))
54-
@inline function _exp2(x::Float32)
55-
TBLBITS = UInt32(4)
56-
TBLSIZE = UInt32(1 << TBLBITS)
5754

58-
redux = Float32(0x1.8p23) / TBLSIZE
59-
P1 = Float32(0x1.62e430p-1)
60-
P2 = Float32(0x1.ebfbe0p-3)
61-
P3 = Float32(0x1.c6b348p-5)
62-
P4 = Float32(0x1.3b2c9cp-7)
63-
64-
# Reduce x, computing z, i0, and k.
65-
t::Float32 = x + redux
66-
i0 = reinterpret(UInt32, t)
67-
i0 += TBLSIZE ÷ UInt32(2)
68-
k::UInt32 = unsafe_trunc(UInt32, (i0 >> TBLBITS) << 20)
69-
i0 &= TBLSIZE - UInt32(1)
70-
t -= redux
71-
z = x - t
72-
twopk = Float32(reinterpret(Float64, UInt64(0x3ff00000 + k) << 32))
73-
74-
# Compute r = exp2(y) = exp2ft[i0] * p(z).
75-
tv = EXP2FT[i0 + UInt32(1)]
76-
u = tv * z
77-
tv = tv + u * (P1 + z * P2) + u * (z * z) * (P3 + z * P4)
78-
79-
# Scale by 2**(k>>20)
80-
return tv * twopk
81-
end
82-
83-
if VERSION < v"1.7.0"
84-
"""
85-
fastpow(x::Real, y::Real) -> Float32
86-
"""
87-
@inline function fastpow(x::Real, y::Real)
88-
if iszero(x)
89-
return 0.0f0
90-
elseif isinf(x) && isinf(y)
91-
return Float32(Inf)
92-
else
93-
return _exp2(convert(Float32, y) * fastlog2(convert(Float32, x)))
94-
end
95-
end
96-
else
97-
"""
98-
fastpow(x::Real, y::Real) -> Float32
99-
"""
100-
@inline function fastpow(x::Real, y::Real)
101-
if iszero(x)
102-
return 0.0f0
103-
elseif isinf(x) && isinf(y)
104-
return Float32(Inf)
105-
else
106-
return @fastmath exp2(convert(Float32, y) * fastlog2(convert(Float32, x)))
107-
end
55+
"""
56+
fastpow(x::T, y::T) where {T} -> T
57+
Trips through Float32 for performance.
58+
"""
59+
@inline function fastpow(x::T, y::T) where {T}
60+
if iszero(x)
61+
return zero(T)
62+
elseif isinf(x) && isinf(y)
63+
return convert(T,Inf)
64+
else
65+
return convert(T,@fastmath exp2(convert(Float32, y) * fastlog2(convert(Float32, x))))
10866
end
10967
end
68+
11069
@inline fastpow(x, y) = x^y

0 commit comments

Comments
 (0)