Skip to content

Commit 2e35294

Browse files
Merge pull request #2977 from DhairyaLGandhi/dg/crc
AD: Add ChainRules extension for MTKParameters construction
2 parents c113dd1 + 93908be commit 2e35294

File tree

5 files changed

+53
-0
lines changed

5 files changed

+53
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,12 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
5757

5858
[weakdeps]
5959
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
60+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6061
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
6162

6263
[extensions]
6364
MTKBifurcationKitExt = "BifurcationKit"
65+
MTKChainRulesCoreExt = "ChainRulesCore"
6466
MTKDeepDiffsExt = "DeepDiffs"
6567

6668
[compat]

ext/MTKChainRulesCoreExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module MTKChainRulesCoreExt
2+
3+
import ModelingToolkit as MTK
4+
import ChainRulesCore
5+
import ChainRulesCore: NoTangent
6+
7+
function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
8+
function mtp_pullback(dt)
9+
(NoTangent(), dt.tunable[1:length(tunables)], ntuple(_ -> NoTangent(), length(args))...)
10+
end
11+
MTK.MTKParameters(tunables, args...), mtp_pullback
12+
end
13+
14+
end

test/extensions/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
[deps]
22
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
33
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
4+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
5+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
6+
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
7+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
8+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/extensions/ad.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using ModelingToolkit
2+
using ModelingToolkit: t_nounits as t, D_nounits as D
3+
using Zygote
4+
using SymbolicIndexingInterface
5+
using SciMLStructures
6+
using OrdinaryDiffEq
7+
using SciMLSensitivity
8+
9+
@variables x(t)[1:3] y(t)
10+
@parameters p[1:3, 1:3] q
11+
eqs = [
12+
D(x) ~ p * x
13+
D(y) ~ sum(p) + q * y
14+
]
15+
u0 = [x => zeros(3),
16+
y => 1.]
17+
ps = [p => zeros(3, 3),
18+
q => 1.]
19+
tspan = (0., 10.)
20+
@mtkbuild sys = ODESystem(eqs, t)
21+
prob = ODEProblem(sys, u0, tspan, ps)
22+
sol = solve(prob, Tsit5())
23+
24+
mtkparams = parameter_values(prob)
25+
new_p = rand(10)
26+
gs = gradient(new_p) do new_p
27+
new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p)
28+
new_prob = remake(prob, p = new_params)
29+
new_sol = solve(new_prob, Tsit5())
30+
sum(new_sol)
31+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,6 @@ end
104104
if GROUP == "All" || GROUP == "Extensions"
105105
activate_extensions_env()
106106
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
107+
@safetestset "Auto Differentiation Test" include("extensions/ad.jl")
107108
end
108109
end

0 commit comments

Comments
 (0)