Skip to content

Commit a6ba8fa

Browse files
authored
Issue 313 (#475)
* WIP * add chainrules frule, rrule * correct merge
1 parent ea0323d commit a6ba8fa

File tree

5 files changed

+33
-2
lines changed

5 files changed

+33
-2
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@ author = "JuliaMath"
55
version = "3.2.5"
66

77
[deps]
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
1011
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1112

1213
[compat]
13-
RecipesBase = "0.7, 0.8, 1"
14+
ChainRulesCore = "1"
1415
MakieCore = "0.6"
16+
RecipesBase = "0.7, 0.8, 1"
1517
julia = "1.6"
1618

1719
[extras]
20+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
1821
DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
1922
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2023
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"

src/Polynomials.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ include("rational-functions/plot-recipes.jl")
3535

3636
# compat; opt-in with `using Polynomials.PolyCompat`
3737
include("polynomials/Poly.jl")
38-
38+
include("chain_rules.jl")
3939
include("precompiles.jl")
4040

4141
end # module

src/chain_rules.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import ChainRulesCore
2+
3+
function ChainRulesCore.frule(
4+
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
5+
(_, Δx),
6+
p::AbstractPolynomial,
7+
x
8+
)
9+
p(x), derivative(p)(x)*Δx
10+
end
11+
12+
13+
function ChainRulesCore.rrule(p::AbstractPolynomial, x)
14+
_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), derivative(p)(x))
15+
return (p(x), _pullback)
16+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
23
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
34
MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
45
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"

test/StandardBasis.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,3 +1603,14 @@ end
16031603
@test Polynomials.minimumexponent(LaurentPolynomial{Float64}) == typemin(Int)
16041604
@test Polynomials.minimumexponent(LaurentPolynomial{Float64, :y}) == typemin(Int)
16051605
end
1606+
1607+
1608+
# Chain rules
1609+
using ChainRulesTestUtils
1610+
1611+
@testset "Test frule and rrule" begin
1612+
p = Polynomial([1,2,3,4])
1613+
dp = derivative(p)
1614+
1615+
test_scalar(p, 1.0; check_inferred=true)
1616+
end

0 commit comments

Comments
 (0)