Skip to content

Commit 5852794

Browse files
Merge pull request #1072 from SciML/fastpow
Add Enzyme support for fastpow
2 parents 869227b + 192bb2f commit 5852794

File tree

5 files changed

+45
-65
lines changed

5 files changed

+45
-65
lines changed

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module DiffEqBaseEnzymeExt
22

33
using DiffEqBase
4-
import DiffEqBase: value
4+
import DiffEqBase: value, fastpow
55
using Enzyme
66
import Enzyme: Const
77
using ChainRulesCore
@@ -53,4 +53,6 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}
5353
return ntuple(_ -> nothing, Val(length(args) + 4))
5454
end
5555

56-
end
56+
Enzyme.Compiler.known_ops[typeof(DiffEqBase.fastpow)] = (:pow, 2, nothing)
57+
58+
end

src/fastpow.jl

Lines changed: 13 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -51,60 +51,20 @@ const EXP2FT = (Float32(0x1.6a09e667f3bcdp-1),
5151
Float32(0x1.3dea64c123422p+0),
5252
Float32(0x1.4bfdad5362a27p+0),
5353
Float32(0x1.5ab07dd485429p+0))
54-
@inline function _exp2(x::Float32)
55-
TBLBITS = UInt32(4)
56-
TBLSIZE = UInt32(1 << TBLBITS)
5754

58-
redux = Float32(0x1.8p23) / TBLSIZE
59-
P1 = Float32(0x1.62e430p-1)
60-
P2 = Float32(0x1.ebfbe0p-3)
61-
P3 = Float32(0x1.c6b348p-5)
62-
P4 = Float32(0x1.3b2c9cp-7)
63-
64-
# Reduce x, computing z, i0, and k.
65-
t::Float32 = x + redux
66-
i0 = reinterpret(UInt32, t)
67-
i0 += TBLSIZE ÷ UInt32(2)
68-
k::UInt32 = unsafe_trunc(UInt32, (i0 >> TBLBITS) << 20)
69-
i0 &= TBLSIZE - UInt32(1)
70-
t -= redux
71-
z = x - t
72-
twopk = Float32(reinterpret(Float64, UInt64(0x3ff00000 + k) << 32))
73-
74-
# Compute r = exp2(y) = exp2ft[i0] * p(z).
75-
tv = EXP2FT[i0 + UInt32(1)]
76-
u = tv * z
77-
tv = tv + u * (P1 + z * P2) + u * (z * z) * (P3 + z * P4)
78-
79-
# Scale by 2**(k>>20)
80-
return tv * twopk
81-
end
82-
83-
if VERSION < v"1.7.0"
84-
"""
85-
fastpow(x::Real, y::Real) -> Float32
86-
"""
87-
@inline function fastpow(x::Real, y::Real)
88-
if iszero(x)
89-
return 0.0f0
90-
elseif isinf(x) && isinf(y)
91-
return Float32(Inf)
92-
else
93-
return _exp2(convert(Float32, y) * fastlog2(convert(Float32, x)))
94-
end
95-
end
96-
else
97-
"""
98-
fastpow(x::Real, y::Real) -> Float32
99-
"""
100-
@inline function fastpow(x::Real, y::Real)
101-
if iszero(x)
102-
return 0.0f0
103-
elseif isinf(x) && isinf(y)
104-
return Float32(Inf)
105-
else
106-
return @fastmath exp2(convert(Float32, y) * fastlog2(convert(Float32, x)))
107-
end
55+
"""
56+
fastpow(x::T, y::T) where {T} -> float(T)
57+
Trips through Float32 for performance.
58+
"""
59+
@inline function fastpow(x::T, y::T) where {T}
60+
outT = float(T)
61+
if iszero(x)
62+
return zero(outT)
63+
elseif isinf(x) && isinf(y)
64+
return convert(outT,Inf)
65+
else
66+
return convert(outT,@fastmath exp2(convert(Float32, y) * fastlog2(convert(Float32, x))))
10867
end
10968
end
69+
11070
@inline fastpow(x, y) = x^y

test/downstream/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
33
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
44
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
55
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
6+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
67
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
78
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
89
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

test/downstream/enzyme.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using Enzyme, EnzymeTestUtils
2+
using DiffEqBase: fastlog2, fastpow
3+
using Test
4+
5+
@testset "Fast pow - Enzyme forward rule" begin
6+
@testset for RT in (Duplicated, DuplicatedNoNeed),
7+
Tx in (Const, Duplicated),
8+
Ty in (Const, Duplicated)
9+
x = 3.0
10+
y = 2.0
11+
test_forward(fastpow, RT, (x, Tx), (y, Ty), atol=0.005, rtol=0.005)
12+
end
13+
end
14+
15+
@testset "Fast pow - Enzyme reverse rule" begin
16+
@testset for RT in (Active,),
17+
Tx in (Active,),
18+
Ty in (Active,)
19+
x = 2.0
20+
y = 3.0
21+
test_reverse(fastpow, RT, (x, Tx), (y, Ty), atol=0.001, rtol=0.001)
22+
end
23+
end

test/fastpow.jl

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using DiffEqBase: fastlog2, _exp2, fastpow
1+
using DiffEqBase: fastlog2, fastpow
22
using Test
33

44
@testset "Fast log2" begin
@@ -7,15 +7,9 @@ using Test
77
end
88
end
99

10-
@testset "Exp2" begin
11-
for x in -100:0.01:3
12-
@test exp2(x)_exp2(Float32(x)) atol=1e-6
13-
end
14-
end
15-
1610
@testset "Fast pow" begin
17-
@test fastpow(1, 1) isa Float32
18-
@test fastpow(1.0, 1.0) isa Float32
11+
@test fastpow(1, 1) isa Float64
12+
@test fastpow(1.0, 1.0) isa Float64
1913
errors = [abs(^(x, y) - fastpow(x, y)) for x in 0.001:0.001:1, y in 0.08:0.001:0.5]
2014
@test maximum(errors) < 1e-4
21-
end
15+
end

0 commit comments

Comments
 (0)