Skip to content

Commit 2c94023

Browse files
authored
Add chain rule for unary + and - (#206)
1 parent e79182a commit 2c94023

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/chain_rules.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import ChainRulesCore
22

3+
ChainRulesCore.@scalar_rule +(x::APL) true
4+
ChainRulesCore.@scalar_rule -(x::APL) -1
5+
36
ChainRulesCore.@scalar_rule +(x::APL, y::APL) (true, true)
47
ChainRulesCore.@scalar_rule -(x::APL, y::APL) (true, -1)
58

test/chain_rules.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,16 @@ end
1717
p = 1.1x + y
1818
q = (-0.1 + im) * x - y
1919

20+
output, pullback = ChainRulesCore.rrule(+, q)
21+
@test output == q
22+
@test pullback(2) == (NoTangent(), 2)
23+
@test pullback(x + 3) == (NoTangent(), x + 3)
24+
25+
output, pullback = ChainRulesCore.rrule(-, q)
26+
@test output -q
27+
@test pullback(2) == (NoTangent(), -2)
28+
@test pullback(x + 3im) == (NoTangent(), -x - 3im)
29+
2030
output, pullback = ChainRulesCore.rrule(+, p, q)
2131
@test output == (1.0 + im)x
2232
@test pullback(2) == (NoTangent(), 2, 2)
@@ -25,13 +35,19 @@ end
2535
output, pullback = ChainRulesCore.rrule(-, p, q)
2636
@test output (1.2 - im) * x + 2y
2737
@test pullback(2) == (NoTangent(), 2, -2)
28-
@test pullback(x + 3) == (NoTangent(), x + 3, -x - 3)
38+
@test pullback(im * x + 3) == (NoTangent(), im * x + 3, -im * x - 3)
2939

3040
output, pullback = ChainRulesCore.rrule(differentiate, p, x)
3141
@test output == 1.1
3242
@test pullback(q) == (NoTangent(), (-0.2 + 2im) * x^2 - x*y, NoTangent())
3343
@test pullback(1x) == (NoTangent(), 2x^2, NoTangent())
3444

45+
test_chain_rule(dot, +, (p,), (q,), p)
46+
test_chain_rule(dot, +, (q,), (p,), q)
47+
48+
test_chain_rule(dot, -, (p,), (q,), p)
49+
test_chain_rule(dot, -, (p,), (p,), q)
50+
3551
test_chain_rule(dot, +, (p, q), (q, p), p)
3652
test_chain_rule(dot, +, (p, q), (p, q), q)
3753

0 commit comments

Comments
 (0)