Skip to content

Commit 0aa03c7

Browse files
Discuss using MTK in loss functions in the FAQ
Fixes #1692
1 parent 3b829f1 commit 0aa03c7

File tree

1 file changed

+54
-1
lines changed

1 file changed

+54
-1
lines changed

docs/src/basics/FAQ.md

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,58 @@ ERROR: TypeError: non-boolean (Num) used in boolean context
4141

4242
then it's likely you are trying to trace through a function which cannot be
4343
directly represented in Julia symbols. The techniques to handle this problem,
44-
such as `@register_symbolic`, are described in detail
44+
such as `@register_symbolic`, are described in detail
4545
[in the Symbolics.jl documentation](https://symbolics.juliasymbolics.org/dev/manual/faq/#Transforming-my-function-to-a-symbolic-equation-has-failed.-What-do-I-do?-1).
46+
47+
## Using ModelingToolkit with Optimization / Automatic Differentiation
48+
49+
If you are using ModelingToolkit inside of a loss function and are having issues with
50+
mixing MTK with automatic differentiation, getting performance, etc... don't! Instead, use
51+
MTK outside of the loss function to generate the code, and then use the generated code
52+
inside of the loss function.
53+
54+
For example, let's say you were building ODEProblems in the loss function like:
55+
56+
```julia
57+
function loss(p)
58+
prob = ODEProblem(sys, [], [p1 => p[1], p2 => p[2]])
59+
sol = solve(prob, Tsit5())
60+
sum(abs2,sol)
61+
end
62+
```
63+
64+
Since `ODEProblem` on a MTK `sys` will have to generate code, this will be slower than
65+
caching the generated code, and will required automatic differentiation to go through the
66+
code generation process itself. All of this is unnecessary. Instead, generate the problem
67+
once outside of the loss function, and remake the prob inside of the loss function:
68+
69+
```julia
70+
prob = ODEProblem(sys, [], [p1 => p[1], p2 => p[2]])
71+
function loss(p)
72+
remake(prob,p = ...)
73+
sol = solve(prob, Tsit5())
74+
sum(abs2,sol)
75+
end
76+
```
77+
78+
Now, one has to be careful with `remake` to ensure that the parameters are in the right
79+
order. One can use the previously mentioned indexing functionality to generate index
80+
maps for reordering `p` like:
81+
82+
```julia
83+
p = @parameters x y z
84+
idxs = ModelingToolkit.varmap_to_vars([p[1] => 1, p[2] => 2, p[3] => 3], p)
85+
p[idxs]
86+
```
87+
88+
Using this, the fixed index map can be used in the loss function. This would look like:
89+
90+
```julia
91+
prob = ODEProblem(sys, [], [p1 => p[1], p2 => p[2]])
92+
idxs = Int.(ModelingToolkit.varmap_to_vars([p1 => 1, p2 => 2], p))
93+
function loss(p)
94+
remake(prob,p = p[idxs])
95+
sol = solve(prob, Tsit5())
96+
sum(abs2,sol)
97+
end
98+
```

0 commit comments

Comments
 (0)