Skip to content

Commit b12e519

Browse files
authored
Add ChainRules support (#202)
* Add ChainRules support * Add chain rule for differentiate * Add chain rule for - * Add chain rule for *
1 parent 9c9c7a2 commit b12e519

File tree

5 files changed

+88
-0
lines changed

5 files changed

+88
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ repo = "https://github.com/JuliaAlgebra/MultivariatePolynomials.jl"
55
version = "0.4.4"
66

77
[deps]
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"

src/MultivariatePolynomials.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,6 @@ include("division.jl")
8888
include("gcd.jl")
8989
include("det.jl")
9090

91+
include("chain_rules.jl")
92+
9193
end # module

src/chain_rules.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import ChainRulesCore
2+
3+
# Both partials of `x + y` are the identity (`true`), so a tangent or cotangent is
# forwarded unchanged to each argument.
ChainRulesCore.@scalar_rule +(x::APL, y::APL) (true, true)
4+
# The partials of `x - y` are `true` (identity) for `x` and `-1` for `y`, so the
# cotangent reaching `y` is negated.
ChainRulesCore.@scalar_rule -(x::APL, y::APL) (true, -1)
5+
6+
"""
    ChainRulesCore.frule((_, Δp, Δq), ::typeof(*), p::APL, q::APL)

Forward-mode rule for the product of two polynomials.

Return `p * q` together with the product-rule tangent `Δp * q + p * Δq`, where
`Δp` and `Δq` are the tangents of `p` and `q`. The tangent of the function `*`
itself is ignored.
"""
function ChainRulesCore.frule((_, Δp, Δq), ::typeof(*), p::APL, q::APL)
    # Keep `Δp` on the left of `q`: the product rule is `Δp * q + p * Δq`, and the
    # operand order matters whenever multiplication is not commutative.
    # `p * Δq` is a temporary created here, so it is handed to `add_mul!!` as the
    # accumulator for the second term.
    return p * q, MA.add_mul!!(p * Δq, Δp, q)
end
9+
"""
    ChainRulesCore.rrule(::typeof(*), p::APL, q::APL)

Reverse-mode rule for the product of two polynomials.

Return `p * q` and a pullback that maps an output cotangent `ΔΩ` to
`(NoTangent(), ΔΩ * q', p' * ΔΩ)`: no tangent for the function `*` itself,
followed by the cotangents of `p` and `q`.
"""
function ChainRulesCore.rrule(::typeof(*), p::APL, q::APL)
    function times_pullback(cotangent)
        # The cotangent feeds both products below, so resolve a possibly lazy
        # thunk once up front; `unthunk` returns any other value unchanged.
        ΔΩ = ChainRulesCore.unthunk(cotangent)
        return (ChainRulesCore.NoTangent(), ΔΩ * q', p' * ΔΩ)
    end
    return p * q, times_pullback
end
17+
18+
"""
    ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x)

Forward-mode rule for `differentiate(p, x)`.

Return the derivative of `p` with respect to `x` together with the derivative
of the tangent `Δp` with respect to `x`. The tangents of `differentiate` itself
and of `x` are ignored.
"""
function ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x)
    primal = differentiate(p, x)
    tangent = differentiate(Δp, x)
    return primal, tangent
end
21+
"""
    pullback(Δdpdx, x)

Pullback used by the reverse-mode rule of `differentiate(p, x)`.

Given the cotangent `Δdpdx` of the derivative, return a 3-tuple: `NoTangent()`
for the function, `x * differentiate(x * Δdpdx, x)` as the cotangent of `p`, and
`NoTangent()` for `x`.
"""
function pullback(Δdpdx, x)
    Δp = x * differentiate(x * Δdpdx, x)
    return ChainRulesCore.NoTangent(), Δp, ChainRulesCore.NoTangent()
end
24+
"""
    ChainRulesCore.rrule(::typeof(differentiate), p, x)

Reverse-mode rule for `differentiate(p, x)`.

Return `differentiate(p, x)` and a callable that applies `pullback` to a
cotangent of the derivative with `x` held fixed.
"""
function ChainRulesCore.rrule(::typeof(differentiate), p, x)
    # Fix `x` as the second argument so the caller only supplies the cotangent.
    differentiate_pullback = Base.Fix2(pullback, x)
    return differentiate(p, x), differentiate_pullback
end

test/chain_rules.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using LinearAlgebra, Test
2+
using ChainRulesCore
3+
4+
"""
    test_chain_rule(dot, op, args, Δin, Δout)

Check the `frule` and `rrule` of `op` evaluated at `args` against each other.

- `dot`: inner product used to compare input and output tangents.
- `op`: the function whose rules are tested.
- `args`: tuple of primal arguments passed to `op`.
- `Δin`: tuple of input tangents, one per element of `args`.
- `Δout`: cotangent of the output of `op`.

Both rules must return the same primal value as `op(args...)`, the pullback
must return `NoTangent()` for the function itself, and the pushforward and
pullback must satisfy `dot(Δin, pullback(Δout)) ≈ dot(pushforward(Δin), Δout)`.
"""
function test_chain_rule(dot, op, args, Δin, Δout)
    output = op(args...)
    foutput, fΔout = ChainRulesCore.frule((NoTangent(), Δin...), op, args...)
    @test output == foutput
    routput, pullback = ChainRulesCore.rrule(op, args...)
    @test output == routput
    rΔin = pullback(Δout)
    @test rΔin[1] == NoTangent()
    # Adjoint identity linking the forward and reverse rules; compared with `≈`
    # because both sides are floating-point inner products.
    @test dot(Δin, rΔin[2:end]) ≈ dot(fΔout, Δout)
end
14+
15+
# Exercises the ChainRulesCore rules for `+`, `-`, `*` and `differentiate` on two
# polynomials in `x` and `y`, one with real and one with complex coefficients.
@testset "ChainRulesCore" begin
    Mod.@polyvar x y
    p = 1.1x + y
    q = (-0.1 + im) * x - y

    output, pullback = ChainRulesCore.rrule(+, p, q)
    @test output == (1.0 + im)x
    @test pullback(2) == (NoTangent(), 2, 2)
    @test pullback(x + 3) == (NoTangent(), x + 3, x + 3)

    output, pullback = ChainRulesCore.rrule(-, p, q)
    # `1.1 - (-0.1)` is not exactly `1.2` in floating point, hence `≈`.
    @test output ≈ (1.2 - im) * x + 2y
    @test pullback(2) == (NoTangent(), 2, -2)
    @test pullback(x + 3) == (NoTangent(), x + 3, -x - 3)

    output, pullback = ChainRulesCore.rrule(differentiate, p, x)
    @test output == 1.1
    @test pullback(q) == (NoTangent(), (-0.2 + 2im) * x^2 - x*y, NoTangent())
    @test pullback(1x) == (NoTangent(), 2x^2, NoTangent())

    test_chain_rule(dot, +, (p, q), (q, p), p)
    test_chain_rule(dot, +, (p, q), (p, q), q)

    test_chain_rule(dot, -, (p, q), (q, p), p)
    test_chain_rule(dot, -, (p, q), (p, q), q)

    test_chain_rule(dot, *, (p, q), (q, p), p * q)
    test_chain_rule(dot, *, (p, q), (p, q), q * q)
    test_chain_rule(dot, *, (q, p), (p, q), q * q)
    test_chain_rule(dot, *, (p, q), (q, p), q * q)

    # Inner product of two polynomials: the `dot` of their coefficient vectors
    # taken over the monomials of `p + q`.
    function _dot(p, q)
        monos = monomials(p + q)
        return dot(coefficient.(p, monos), coefficient.(q, monos))
    end
    # Tangent tuples for `differentiate` carry `NoTangent()` in the slot of `x`,
    # so only the polynomial components are compared.
    function _dot(
        px::Tuple{<:AbstractPolynomial,NoTangent},
        qx::Tuple{<:AbstractPolynomial,NoTangent},
    )
        return _dot(px[1], qx[1])
    end
    test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), p)
    test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), differentiate(p, x))
    test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), differentiate(q, x))
    test_chain_rule(_dot, differentiate, (p, x), (p * q, NoTangent()), p)
end

test/commutativetests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ include("comparison.jl")
2020
include("substitution.jl")
2121
include("differentiation.jl")
2222
include("division.jl")
23+
include("chain_rules.jl")
2324

2425
include("show.jl")
2526

0 commit comments

Comments
 (0)