|
17 | 17 | p = 1.1x + y
|
18 | 18 | q = (-0.1 + im) * x - y
|
19 | 19 |
|
| 20 | + output, pullback = ChainRulesCore.rrule(+, q) |
| 21 | + @test output == q |
| 22 | + @test pullback(2) == (NoTangent(), 2) |
| 23 | + @test pullback(x + 3) == (NoTangent(), x + 3) |
| 24 | + |
| 25 | + output, pullback = ChainRulesCore.rrule(-, q) |
| 26 | + @test output ≈ -q |
| 27 | + @test pullback(2) == (NoTangent(), -2) |
| 28 | + @test pullback(x + 3im) == (NoTangent(), -x - 3im) |
| 29 | + |
20 | 30 | output, pullback = ChainRulesCore.rrule(+, p, q)
|
21 | 31 | @test output == (1.0 + im)x
|
22 | 32 | @test pullback(2) == (NoTangent(), 2, 2)
|
|
25 | 35 | output, pullback = ChainRulesCore.rrule(-, p, q)
|
26 | 36 | @test output ≈ (1.2 - im) * x + 2y
|
27 | 37 | @test pullback(2) == (NoTangent(), 2, -2)
|
28 |
| - @test pullback(x + 3) == (NoTangent(), x + 3, -x - 3) |
| 38 | + @test pullback(im * x + 3) == (NoTangent(), im * x + 3, -im * x - 3) |
29 | 39 |
|
30 | 40 | output, pullback = ChainRulesCore.rrule(differentiate, p, x)
|
31 | 41 | @test output == 1.1
|
32 | 42 | @test pullback(q) == (NoTangent(), (-0.2 + 2im) * x^2 - x*y, NoTangent())
|
33 | 43 | @test pullback(1x) == (NoTangent(), 2x^2, NoTangent())
|
34 | 44 |
|
| 45 | + test_chain_rule(dot, +, (p,), (q,), p) |
| 46 | + test_chain_rule(dot, +, (q,), (p,), q) |
| 47 | + |
| 48 | + test_chain_rule(dot, -, (p,), (q,), p) |
| 49 | + test_chain_rule(dot, -, (p,), (p,), q) |
| 50 | + |
35 | 51 | test_chain_rule(dot, +, (p, q), (q, p), p)
|
36 | 52 | test_chain_rule(dot, +, (p, q), (p, q), q)
|
37 | 53 |
|
|
0 commit comments