Skip to content

Commit 917b775

Browse files
authored
Merge pull request #283 from JuliaDynamics/hw/sparsity
automatic sparsity patterns for networks
2 parents a50157a + 40c0842 commit 917b775

19 files changed

+780
-35
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# NetworkDynamics Release Notes
22

3+
## v0.10.1 Changelog
4+
- [#283](https://github.com/JuliaDynamics/NetworkDynamics.jl/pull/283) add automatic sparsity detection using `get_jac_prototype` and `set_jac_prototype!`
5+
36
## v0.10 Changelog
47
- **BREAKING**: the interface initialization of components has changed: it is now split up in two versions, mutating and non mutating version. Also it errors now if the tolerance bounds are violated. See docs on initialization for more details.
58

NetworkDynamicsInspector/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NetworkDynamicsInspector"
22
uuid = "0a4713f2-d58f-43f2-b63b-1b5d5ee4e65a"
33
authors = ["Hans Würfel <[email protected]>"]
4-
version = "0.1.7"
4+
version = "0.1.8"
55

66
[deps]
77
Bonito = "824d6782-a2ef-11e9-3a09-e5662e0c26f8"

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,16 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3939
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4040
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
4141
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
42+
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
43+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
4244
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
4345
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
4446

4547
[extensions]
4648
NetworkDynamicsCUDAExt = ["CUDA", "Adapt"]
4749
NetworkDynamicsDataFramesExt = ["DataFrames"]
4850
NetworkDynamicsMTKExt = ["ModelingToolkit", "SymbolicUtils", "Symbolics"]
51+
NetworkDynamicsSparsityExt = ["SparseConnectivityTracer", "RuntimeGeneratedFunctions"]
4952
NetworkDynamicsSymbolicsExt = ["Symbolics", "MacroTools"]
5053

5154
[compat]
@@ -74,8 +77,10 @@ PrecompileTools = "1.2.1"
7477
Printf = "1.10.0"
7578
Random = "1"
7679
RecursiveArrayTools = "3.27.0"
80+
RuntimeGeneratedFunctions = "0.5.15"
7781
SciMLBase = "2"
7882
SparseArrays = "1"
83+
SparseConnectivityTracer = "0.6"
7984
Static = "1.1.1"
8085
StaticArrays = "1.9.4"
8186
SteadyStateDiffEq = "2.2.0"

docs/Project.toml

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ name = "ND.jl docs"
33
[deps]
44
Bonito = "824d6782-a2ef-11e9-3a09-e5662e0c26f8"
55
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
6+
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
67
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
78
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
89
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
@@ -27,20 +28,15 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2728
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2829
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2930
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
31+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
3032
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3133
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3234
StyledStrings = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
3335
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3436

35-
[sources.Bonito]
36-
rev = "master"
37-
url = "https://github.com/SimonDanisch/Bonito.jl"
38-
39-
[sources.NetworkDynamics]
40-
path = ".."
41-
42-
[sources.NetworkDynamicsInspector]
43-
path = "../NetworkDynamicsInspector"
37+
[sources]
38+
NetworkDynamics = {path = ".."}
39+
NetworkDynamicsInspector = {path = "../NetworkDynamicsInspector"}
4440

4541
[compat]
4642
Bonito = "≥0.0.1"

docs/make.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using ModelingToolkit
1010
using DataFrames: DataFrames
1111
using DocumenterInterLinks
1212
using Electron
13+
using SparseConnectivityTracer
1314

1415

1516
@info "Create global electron window"
@@ -36,10 +37,11 @@ end
3637

3738
mtkext = Base.get_extension(NetworkDynamics, :NetworkDynamicsMTKExt)
3839
dfext = Base.get_extension(NetworkDynamics, :NetworkDynamicsDataFramesExt)
40+
sparsityext = Base.get_extension(NetworkDynamics, :NetworkDynamicsSparsityExt)
3941
kwargs = (;
4042
root=joinpath(pkgdir(NetworkDynamics), "docs"),
4143
sitename="NetworkDynamics",
42-
modules=[NetworkDynamics, mtkext, dfext, NetworkDynamicsInspector],
44+
modules=[NetworkDynamics, mtkext, dfext, sparsityext, NetworkDynamicsInspector],
4345
linkcheck=true, # checks if external links resolve
4446
pagesonly=true,
4547
plugins=[links],
@@ -54,6 +56,7 @@ kwargs = (;
5456
"initialization.md",
5557
"callbacks.md",
5658
"mtk_integration.md",
59+
"sparsity_detection.md",
5760
"external_inputs.md",
5861
"inspector.md",
5962
],

docs/src/API.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,12 @@ set_callback!
200200
add_callback!
201201
```
202202

203+
## Sparsity Detection
204+
```@docs
205+
get_jac_prototype
206+
set_jac_prototype!
207+
```
208+
203209
## Execution Types
204210
```@docs
205211
ExecutionStyle

docs/src/sparsity_detection.md

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Sparsity Detection
2+
3+
NetworkDynamics.jl can automatically detect and exploit the sparsity structure of the Jacobian matrix to significantly improve the performance of ODE solvers. This feature uses [SparseConnectivityTracer.jl](https://github.com/adrhill/SparseConnectivityTracer.jl) to analyze the network's dynamics and create a sparse Jacobian prototype that modern solvers can use for more efficient linear algebra operations.
4+
5+
The sparsity detection is particularly beneficial for:
6+
- Large networks where the Jacobian matrix is sparse
7+
- Stiff systems that require implicit solvers
8+
- Networks with complex component interactions
9+
- Components with conditional statements that complicate automatic differentiation
10+
11+
## Core Function
12+
13+
The main interface is the [`get_jac_prototype`](@ref) function, which take the a `Network` object as an argument and returns a sparse boolean matrix containing the sparsity pattern.
14+
15+
The sparsity pattern can be passed to ODE solvers to improve performance:
16+
17+
```julia
18+
f_ode = ODEFunction(nw; jac_prototype=get_jac_prototype(nw))
19+
prob = ODEProblem(f_ode, x0, (0.0, 1.0), p0)
20+
sol = solve(prob, Rodas5P())
21+
```
22+
23+
Alternatively, you can store the sparsity pattern directly in the network:
24+
25+
```julia
26+
set_jac_prototype!(nw; kwargs_for_get_jac_prototype...)
27+
prob = ODEProblem(nw, x0, (0.0, 1.0), p0) # automatically uses stored prototype
28+
```
29+
30+
## Example: Handling Conditional Statements
31+
32+
A key feature of NetworkDynamics.jl's sparsity detection is the ability to handle conditional statements in component functions. This is particularly useful for ModelingToolkit-based components that use `ifelse` statements.
33+
34+
The conditional statements will be resolved in favor of a "global" sparsity pattern by
35+
replacing them temporarily with `trueblock + falseblock` which is then inferable by
36+
SparseConnectivityTracer.jl.
37+
38+
!!! details "Setup code"
39+
```@example sparsity
40+
using NetworkDynamics, ModelingToolkit, Graphs
41+
using SparseArrays, OrdinaryDiffEqRosenbrock, OrdinaryDiffEqNonlinearSolve
42+
using ModelingToolkit: D_nounits as Dt, t_nounits as t
43+
nothing #hide
44+
```
45+
46+
```@example sparsity
47+
# Define a component with conditional logic
48+
@mtkmodel ValveModel begin
49+
@variables begin
50+
p_src(t), [description="source pressure"]
51+
p_dst(t), [description="destination pressure"]
52+
q(t), [description="flow through valve"]
53+
end
54+
@parameters begin
55+
K=1, [description="conductance"]
56+
active=1, [description="valve state"]
57+
end
58+
@equations begin
59+
q ~ ifelse(active > 0, K * (p_src - p_dst), 0)
60+
end
61+
end
62+
63+
@mtkmodel NodeModel begin
64+
@variables begin
65+
p(t)=1, [description="pressure"]
66+
q_nw(t), [description="network flow"]
67+
end
68+
@parameters begin
69+
C=1, [description="capacitance"]
70+
q_ext, [description="external flow"]
71+
end
72+
@equations begin
73+
C*Dt(p) ~ q_ext + q_nw
74+
end
75+
end
76+
nothing # hide
77+
```
78+
79+
```@example sparsity
80+
# Create network
81+
@named valve = ValveModel()
82+
@named node = NodeModel()
83+
84+
g = wheel_graph(10)
85+
v = VertexModel(node, [:q_nw], [:p])
86+
e = EdgeModel(valve, [:p_src], [:p_dst], AntiSymmetric([:q]))
87+
88+
nw = Network(g, v, e)
89+
```
90+
91+
```@example sparsity
92+
# This will fail due to conditional statements
93+
try
94+
get_jac_prototype(nw)
95+
catch
96+
println("Error: Sparsity detection failed due to conditional statements")
97+
end
98+
```
99+
100+
```@example sparsity
101+
# This works by removing conditionals
102+
jac_prototype = get_jac_prototype(nw; remove_conditions=true)
103+
104+
# Store the prototype directly in the network
105+
set_jac_prototype!(nw, jac_prototype)
106+
```
107+
108+
## Performance Benefits
109+
110+
Using sparsity detection can significantly improve solver performance, especially for large networks and stiff systems:
111+
112+
```@example sparsity
113+
using OrdinaryDiffEqRosenbrock, Chairmarks
114+
115+
# Create a large sparse network for benchmarking
116+
g_large = grid([20, 20]) # 400 nodes in a 2D grid (very sparse)
117+
nw_large = Network(g_large, v, e)
118+
119+
# Setup initial conditions and parameters
120+
using Random # hide
121+
Random.seed!(42) # hide
122+
s0 = NWState(nw_large)
123+
s0.v[:, :p] .= randn(400) # random initial pressures
124+
125+
p0 = NWParameter(nw_large)
126+
p0.v[:, :q_ext] .= randn(400) # small external flow
127+
128+
nothing #hide
129+
```
130+
131+
The network is now ready for benchmarking. Let's first time the solution without sparsity detection:
132+
133+
```@example sparsity
134+
# Without sparsity detection (dense Jacobian)
135+
prob_dense = ODEProblem(nw_large, uflat(s0), (0.0, 1.0), pflat(p0))
136+
@b solve($prob_dense, Rodas5P()) seconds=1
137+
```
138+
139+
Now let's enable sparsity detection:
140+
```@example sparsity
141+
jac = get_jac_prototype(nw_large; remove_conditions=true)
142+
```
143+
The pattern already shows that the Jacobian is really sparse due to the sparse network connections.
144+
145+
```@example sparsity
146+
set_jac_prototype!(nw_large, jac)
147+
```
148+
149+
Now we can benchmark the sparse version:
150+
```@example sparsity
151+
# Solve with sparsity detection
152+
prob_sparse = ODEProblem(nw_large, uflat(s0), (0.0, 1.0), pflat(p0))
153+
@b solve($prob_sparse, Rodas5P()) seconds=1
154+
```
155+
156+
For this network, we see a substantial speedup due to the sparse solver!
157+
158+
## Troubleshooting
159+
160+
**Sparsity detection fails with conditional statements:**
161+
- Use `remove_conditions=true` to handle `ifelse` statements in MTK components
162+
- For specific problematic components, pass a vector of indices: `remove_conditions=[EIndex(1), VIndex(2)]`
163+
164+
**Detection fails for complex components:**
165+
- Use `dense=true` to treat all components as dense (fallback option)
166+
- For specific components, use `dense=[EIndex(1)]` to treat only those components as dense
167+
168+
**Performance doesn't improve:**
169+
- Sparsity detection is most beneficial for large networks (>50 nodes) with sparse connectivity
170+
- Dense networks or small systems may not see significant speedup
171+
- Ensure you're using a solver that can exploit sparsity (e.g., `Rodas5P`, `FBDF`)
172+
173+
The sparsity detection feature requires the `SparseConnectivityTracer.jl` package, which needs to be loaded manually!

ext/NetworkDynamicsMTKExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ _all_rhs_symbols(eqs) = mapreduce(eq->get_variables(eq.rhs), ∪, eqs, init=Set{
557557

558558
using PrecompileTools: @compile_workload
559559
@compile_workload begin
560-
include("MTKExt_precomp_workload.jl")
560+
# include("MTKExt_precomp_workload.jl")
561561
end
562562

563563
end

0 commit comments

Comments
 (0)