Skip to content

Commit 8b12f2e

Browse files
committed
Enable tag checking
1 parent 1311519 commit 8b12f2e

File tree

3 files changed

+26
-23
lines changed

3 files changed

+26
-23
lines changed

Manifest.toml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
164164

165165
[[deps.DiffEqBase]]
166166
deps = ["ArrayInterface", "ChainRulesCore", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces", "ZygoteRules"]
167-
git-tree-sha1 = "94384b09e50ea01819b6db01ac08403ebe09bf65"
168-
repo-rev = "ap/tstable_termination"
167+
git-tree-sha1 = "53ad089996089756cae5a098b1a0542aeaab466f"
168+
repo-rev = "master"
169169
repo-url = "https://github.com/SciML/DiffEqBase.jl"
170170
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
171171
version = "6.136.0"
@@ -425,9 +425,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
425425

426426
[[deps.LinearSolve]]
427427
deps = ["ArrayInterface", "ConcreteStructs", "DocStringExtensions", "EnumX", "EnzymeCore", "FastLapackInterface", "GPUArraysCore", "InteractiveUtils", "KLU", "Krylov", "Libdl", "LinearAlgebra", "MKL_jll", "PrecompileTools", "Preferences", "RecursiveFactorization", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Sparspak", "SuiteSparse", "UnPack"]
428-
git-tree-sha1 = "27732d23d88534a7b735dcf8f411daf34293a39e"
428+
git-tree-sha1 = "9f807ca41005f9a8f092716e48022ee5b36cf5b1"
429429
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
430-
version = "2.14.0"
430+
version = "2.14.1"
431431

432432
[deps.LinearSolve.extensions]
433433
LinearSolveBandedMatricesExt = "BandedMatrices"
@@ -639,9 +639,9 @@ version = "1.3.4"
639639

640640
[[deps.RecursiveArrayTools]]
641641
deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "Requires", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"]
642-
git-tree-sha1 = "d7087c013e8a496ff396bae843b1e16d9a30ede8"
642+
git-tree-sha1 = "fa453b42ba1623bd2e70260bf44dac850a3430a7"
643643
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
644-
version = "2.38.10"
644+
version = "2.39.0"
645645

646646
[deps.RecursiveArrayTools.extensions]
647647
RecursiveArrayToolsMeasurementsExt = "Measurements"
@@ -689,9 +689,9 @@ version = "0.1.0"
689689

690690
[[deps.SLEEFPirates]]
691691
deps = ["IfElse", "Static", "VectorizationBase"]
692-
git-tree-sha1 = "f5c896d781486f1d67c8492f0e0ead2c3517208c"
692+
git-tree-sha1 = "3aac6d68c5e57449f5b9b865c9ba50ac2970c4cf"
693693
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
694-
version = "0.6.41"
694+
version = "0.6.42"
695695

696696
[[deps.SciMLBase]]
697697
deps = ["ADTypes", "ArrayInterface", "ChainRulesCore", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FillArrays", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces", "ZygoteRules"]
@@ -766,7 +766,7 @@ version = "1.10.0"
766766

767767
[[deps.SparseDiffTools]]
768768
deps = ["ADTypes", "Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "Reexport", "SciMLOperators", "Setfield", "SparseArrays", "StaticArrayInterface", "StaticArrays", "Tricks", "UnPack", "VertexSafeGraphs"]
769-
git-tree-sha1 = "5188e5e415908a19a41cd90d8ab74a23affacba6"
769+
git-tree-sha1 = "888937b8348e1e9ffae1c31efa61e693bc5463ba"
770770
repo-rev = "ap/tagging"
771771
repo-url = "https://github.com/avik-pal/SparseDiffTools.jl"
772772
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

src/utils.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ end
1010

1111
struct NonlinearSolveTag end
1212

13+
function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:NonlinearSolveTag, <:T}}, f::F,
14+
x::AbstractArray{T}) where {T, F}
15+
return true
16+
end
17+
1318
"""
1419
default_adargs_to_adtype(; chunk_size = Val{0}(), autodiff = Val{true}(),
1520
standardtag = Val{true}(), diff_type = Val{:forward})
@@ -43,7 +48,8 @@ function default_adargs_to_adtype(; chunk_size = missing, autodiff = nothing,
4348

4449
ad = _unwrap_val(autodiff)
4550
# We don't really know the typeof the input yet, so we can't use the correct tag!
46-
ad && return AutoForwardDiff{_unwrap_val(chunk_size), Nothing}(nothing)
51+
ad && return AutoForwardDiff{_unwrap_val(chunk_size), NonlinearSolveTag}(;
52+
tag = NonlinearSolveTag())
4753
return AutoFiniteDiff(; fdtype = diff_type)
4854
end
4955

@@ -117,17 +123,6 @@ function wrapprecs(_Pl, _Pr, weight)
117123
return Pl, Pr
118124
end
119125

120-
function _nfcount(N, ::Type{diff_type}) where {diff_type}
121-
if diff_type === Val{:complex}
122-
tmp = N
123-
elseif diff_type === Val{:forward}
124-
tmp = N + 1
125-
else
126-
tmp = 2N
127-
end
128-
return tmp
129-
end
130-
131126
get_loss(fu) = norm(fu)^2 / 2
132127

133128
function rfunc(r::R, c2::R, M::R, γ1::R, γ2::R, β::R) where {R <: Real} # R-function for adaptive trust region method
@@ -203,7 +198,7 @@ function __get_concrete_algorithm(alg, prob)
203198
use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff()
204199
else
205200
(use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)(;
206-
tag = ForwardDiff.Tag(NonlinearSolveTag(), eltype(prob.u0)))
201+
tag = NonlinearSolveTag())
207202
end
208203
return set_ad(alg, ad)
209204
end

test/sparse.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ end
4141
u0 = init_brusselator_2d(xyd_brusselator)
4242
prob_brusselator_2d = NonlinearProblem(brusselator_2d_loop, u0, p)
4343
sol = solve(prob_brusselator_2d, NewtonRaphson())
44+
@test norm(sol.resid) < 1e-8
45+
46+
sol = solve(prob_brusselator_2d, NewtonRaphson(autodiff = AutoSparseForwardDiff()))
47+
@test norm(sol.resid) < 1e-8
48+
49+
sol = solve(prob_brusselator_2d, NewtonRaphson(autodiff = AutoSparseFiniteDiff()))
50+
@test norm(sol.resid) < 1e-8
4451

4552
du0 = copy(u0)
4653
jac_sparsity = Symbolics.jacobian_sparsity((du, u) -> brusselator_2d_loop(du, u, p), du0,
@@ -57,7 +64,8 @@ sol = solve(prob_brusselator_2d, NewtonRaphson())
5764
@test !all(iszero, jac_prototype)
5865

5966
sol = solve(prob_brusselator_2d, NewtonRaphson(autodiff = AutoSparseFiniteDiff()))
60-
@test norm(sol.resid) < 1e-6
67+
@test norm(sol.resid) < 1e-8
6168

6269
cache = init(prob_brusselator_2d, NewtonRaphson(; autodiff = AutoSparseForwardDiff()));
6370
@test maximum(cache.jac_cache.coloring.colorvec) == 12
71+
@test cache.alg.ad isa AutoSparseForwardDiff

0 commit comments

Comments
 (0)