Skip to content

Commit dbdcad5

Browse files
kshyattKatharine Hyatt
andauthored
Add easy_rule for BigFloat division (#2934)
* Add easy_rule for BigFloat division * Typo fix * Fix? * Add in multiplication too * More rules and tests and mixed types * Do inv too * Add simple trig functions * Fixed tests w autodiff * Don't use rand * Fix runtests.jl --------- Co-authored-by: Katharine Hyatt <[email protected]>
1 parent 46d5a1d commit dbdcad5

File tree

2 files changed

+37
-14
lines changed

2 files changed

+37
-14
lines changed

src/internal_rules/bigfloat.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ function EnzymeRules.forward(
77

88
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
99
if EnzymeRules.width(config) == 1
10-
return RT(Ty.val(; kwargs...), Ty.val(; kwargs...))
10+
return remove_innerty(RT)(Ty.val(; kwargs...), Ty.val(; kwargs...))
1111
else
1212
tup = ntuple(Val(EnzymeRules.width(config))) do i
1313
Base.@_inline_meta
1414
Ty.val(; kwargs...)
1515
end
16-
return RT(Ty.val(; kwargs...), tup)
16+
return remove_innerty(RT)(Ty.val(; kwargs...), tup)
1717
end
1818
elseif EnzymeRules.needs_shadow(config)
1919
if EnzymeRules.width(config) == 1
@@ -67,5 +67,19 @@ function EnzymeRules.reverse(
6767
return ()
6868
end
6969

70+
EnzymeRules.@easy_rule(+(a::BigFloat, b::Number), (1,1))
71+
EnzymeRules.@easy_rule(+(a::Number, b::BigFloat), (1,1))
7072
EnzymeRules.@easy_rule(+(a::BigFloat, b::BigFloat), (1,1))
73+
EnzymeRules.@easy_rule(-(a::BigFloat, b::Number), (1,-1))
74+
EnzymeRules.@easy_rule(-(a::Number, b::BigFloat), (1,-1))
7175
EnzymeRules.@easy_rule(-(a::BigFloat, b::BigFloat), (1,-1))
76+
EnzymeRules.@easy_rule(*(a::BigFloat, b::BigFloat), (b, a))
77+
EnzymeRules.@easy_rule(*(a::BigFloat, b::Number), (b, a))
78+
EnzymeRules.@easy_rule(*(a::Number, b::BigFloat), (b, a))
79+
EnzymeRules.@easy_rule(/(a::BigFloat, b::Number), (one(a)/b, -(a/b^2)))
80+
EnzymeRules.@easy_rule(/(a::Number, b::BigFloat), (one(a)/b, -(a/b^2)))
81+
EnzymeRules.@easy_rule(/(a::BigFloat, b::BigFloat), (one(a)/b, -(a/b^2)))
82+
EnzymeRules.@easy_rule(Base.inv(a::BigFloat), (-(one(a)/a^2),))
83+
EnzymeRules.@easy_rule(Base.sin(a::BigFloat), (cos(a),))
84+
EnzymeRules.@easy_rule(Base.cos(a::BigFloat), (-sin(a),))
85+
EnzymeRules.@easy_rule(Base.tan(a::BigFloat), (one(a) + Ω^2,))

test/rules/internal_rules/bigfloat.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,27 @@ using EnzymeTestUtils
33
using FiniteDifferences
44
using Test
55

6-
@testset "BigFloat +/-" begin
7-
a = rand(BigFloat)
8-
b = rand(BigFloat)
6+
@testset "BigFloat arithmetic" begin
7+
a = BigFloat(1.234)
8+
da = BigFloat(-0.23)
9+
b = BigFloat(0.56)
10+
db = BigFloat(0.27)
11+
af64 = 1.234 # for testing mixed methods
12+
daf64 = -0.23 # for testing mixed methods
13+
bf64 = 0.56 # for testing mixed methods
14+
dbf64 = 0.27 # for testing mixed methods
915

10-
# doesn't work because of https://github.com/EnzymeAD/Enzyme.jl/issues/2888
11-
#test_reverse(+, Const, (a, Const), (b, Const))
12-
#test_reverse(+, Active, (a, Active), (b, Active))
13-
#test_reverse(-, Const, (a, Const), (b, Const))
14-
#test_reverse(-, Active, (a, Active), (b, Active))
16+
@test autodiff(Enzyme.Forward, +, Duplicated, Duplicated(a, da), Duplicated(b, db))[:1] da+db
17+
@test autodiff(Enzyme.Forward, +, Duplicated, Duplicated(a, da), Duplicated(bf64, dbf64))[:1] da+dbf64
18+
@test autodiff(Enzyme.Forward, -, Duplicated, Duplicated(a, da), Duplicated(b, db))[:1] da-db
19+
@test autodiff(Enzyme.Forward, -, Duplicated, Duplicated(a, da), Duplicated(bf64, dbf64))[:1] da-dbf64
20+
@test autodiff(Enzyme.Forward, *, Duplicated, Duplicated(a, da), Duplicated(b, db))[:1] b*da + a*db
21+
@test autodiff(Enzyme.Forward, *, Duplicated, Duplicated(a, da), Duplicated(bf64, dbf64))[:1] bf64*da + a*dbf64
22+
@test autodiff(Enzyme.Forward, /, Duplicated, Duplicated(a, da), Duplicated(b, db))[:1] da/b - db * a/b^2
23+
@test autodiff(Enzyme.Forward, /, Duplicated, Duplicated(a, da), Duplicated(bf64, dbf64))[:1] da/bf64 - dbf64 * a/bf64^2
1524

16-
test_forward(+, Const, (a, Const), (b, Const))
17-
test_forward(+, Duplicated, (a, Duplicated), (b, Duplicated))
18-
test_forward(-, Const, (a, Const), (b, Const))
19-
test_forward(-, Duplicated, (a, Duplicated), (b, Duplicated))
25+
@test autodiff(Enzyme.Forward, inv, Duplicated, Duplicated(a, da))[:1] -(one(BigFloat)/a^2) * da
26+
@test autodiff(Enzyme.Forward, sin, Duplicated, Duplicated(a, da))[:1] cos(a) * da
27+
@test autodiff(Enzyme.Forward, cos, Duplicated, Duplicated(a, da))[:1] -sin(a) * da
28+
@test autodiff(Enzyme.Forward, tan, Duplicated, Duplicated(a, da))[:1] autodiff(Enzyme.Forward, tan, Duplicated, Duplicated(af64, daf64))[1]
2029
end

0 commit comments

Comments
 (0)