Skip to content

Commit 7da8910

Browse files
committed
Add extension for SparseConnectivityTracer
1 parent 9b1b2d7 commit 7da8910

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PreallocationTools"
22
uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "0.4.24"
4+
version = "0.4.25"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -10,9 +10,11 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1010

1111
[weakdeps]
1212
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
13+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
1314

1415
[extensions]
1516
PreallocationToolsReverseDiffExt = "ReverseDiff"
17+
PreallocationToolsSparseConnectivityTracerExt = "SparseConnectivityTracer"
1618

1719
[compat]
1820
Adapt = "3.4, 4"
@@ -30,6 +32,7 @@ RecursiveArrayTools = "3.2"
3032
ReverseDiff = "1"
3133
SafeTestsets = "0.1"
3234
SparseArrays = "1"
35+
SparseConnectivityTracer = "0.6.12"
3336
Symbolics = "5.12"
3437
Test = "1"
3538
julia = "1.10"
@@ -47,8 +50,9 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
4750
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
4851
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4952
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
53+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
5054
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
5155
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5256

5357
[targets]
54-
test = ["Aqua", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics"]
58+
test = ["Aqua", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics", "SparseConnectivityTracer"]
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
module PreallocationToolsSparseConnectivityTracerExt
2+
3+
using PreallocationTools
4+
isdefined(Base, :get_extension) ? (import SparseConnectivityTracer) :
5+
(import ..SparseConnectivityTracer)
6+
7+
function PreallocationTools.get_tmp(
8+
dc::DiffCache, u::T) where {T <: SparseConnectivityTracer.Dual}
9+
if isbitstype(T)
10+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
11+
if nelem > length(dc.dual_du)
12+
PreallocationTools.enlargediffcache!(dc, nelem)
13+
end
14+
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
15+
else
16+
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
17+
end
18+
end
19+
20+
function PreallocationTools.get_tmp(
21+
dc::DiffCache, ::Type{T}) where {T <: SparseConnectivityTracer.Dual}
22+
if isbitstype(T)
23+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
24+
if nelem > length(dc.dual_du)
25+
PreallocationTools.enlargediffcache!(dc, nelem)
26+
end
27+
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
28+
else
29+
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
30+
end
31+
end
32+
33+
function PreallocationTools.get_tmp(
34+
dc::DiffCache, u::AbstractArray{T}) where {T <: SparseConnectivityTracer.Dual}
35+
if isbitstype(T)
36+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
37+
if nelem > length(dc.dual_du)
38+
PreallocationTools.enlargediffcache!(dc, nelem)
39+
end
40+
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
41+
else
42+
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
43+
end
44+
end
45+
46+
end

0 commit comments

Comments
 (0)