Skip to content

Commit 91970fe

Browse files
Update PreallocationToolsForwardDiffExt.jl
1 parent d74ddb2 commit 91970fe

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

ext/PreallocationToolsForwardDiffExt.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,16 @@ using ForwardDiff
55
using ArrayInterface
66
using Adapt
77

8-
# Initialize on module load
9-
function __init__()
10-
# Set the dual array creator function
11-
PreallocationTools.DUAL_ARRAY_CREATOR[] = function(u::AbstractArray{T}, siz,
8+
function PreallocationTools.dualarraycreator(u::AbstractArray{T}, siz,
129
::Type{Val{chunk_size}}) where {T, chunk_size}
13-
ArrayInterface.restructure(u,
14-
zeros(ForwardDiff.Dual{Nothing, T, chunk_size},
15-
siz...))
16-
end
17-
18-
# Set the chunk size function to use ForwardDiff's pickchunksize
19-
PreallocationTools.CHUNK_SIZE_FUNC[] = ForwardDiff.pickchunksize
10+
ArrayInterface.restructure(u,
11+
zeros(ForwardDiff.Dual{Nothing, T, chunk_size},
12+
siz...))
2013
end
2114

15+
PreallocationTools.pickchunksize(x::Number) = ForwardDiff.pickchunksize(x)
16+
PreallocationTools.pickchunksize(x::AbstractArray) = ForwardDiff.pickchunksize(x)
17+
2218
# Define chunksize for ForwardDiff.Dual types
2319
PreallocationTools.chunksize(::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = N
2420

@@ -86,4 +82,4 @@ function PreallocationTools.get_tmp(dc::PreallocationTools.DiffCache, u::Abstrac
8682
end
8783
end
8884

89-
end
85+
end

0 commit comments

Comments
 (0)