Skip to content

Commit a15e8d1

Browse files
Fix ADTests
1 parent 8134c74 commit a15e8d1

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

lib/OrdinaryDiffEqCore/src/misc_utils.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ function _bool_to_ADType(::Val{false}, _, ::Val{FD}) where {FD}
147147
end
148148

149149
# Functions to get ADType type from Bool or ADType object, or ADType type
150+
function _process_AD_choice(ad_alg::Bool, CS::Int, ::Val{FD}) where {FD}
151+
return _bool_to_ADType(Val(ad_alg), Val{CS}(), Val{FD}()), Val{CS}(), Val{FD}()
152+
end
153+
150154
function _process_AD_choice(ad_alg::Bool, ::Val{CS}, ::Val{FD}) where {CS, FD}
151155
return _bool_to_ADType(Val(ad_alg), Val{CS}(), Val{FD}()), Val{CS}(), Val{FD}()
152156
end
@@ -162,6 +166,17 @@ function _process_AD_choice(
162166
return ad_alg, Val{_CS}(), Val{FD}()
163167
end
164168

169+
function _process_AD_choice(
170+
ad_alg::AutoForwardDiff{CS}, CS2::Int, ::Val{FD}) where {CS, FD}
171+
# Non-default `chunk_size`
172+
if CS2 != 0
173+
@warn "The `chunk_size` keyword is deprecated. Please use an `ADType` specifier. For now defaulting to using `AutoForwardDiff` with `chunksize=$(CS2)`."
174+
return _bool_to_ADType(Val{true}(), Val{CS2}(), Val{FD}()), Val{CS2}(), Val{FD}()
175+
end
176+
_CS = CS === nothing ? 0 : CS
177+
return ad_alg, Val{_CS}(), Val{FD}()
178+
end
179+
165180
function _process_AD_choice(
166181
ad_alg::AutoFiniteDiff{FD}, ::Val{CS}, ::Val{FD2}) where {FD, CS, FD2}
167182
# Non-default `diff_type`

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -763,8 +763,9 @@ function alg_cache(
763763
du1 = zero(rate_prototype)
764764
du2 = zero(rate_prototype)
765765

766-
dtC = similar(tab.C)
767-
dtd = similar(tab.d)
766+
# Promote t-type for AD
767+
dtC = similar(tab.C) .* dt .* false
768+
dtd = similar(tab.d) .* dt .* false
768769

769770
# Initialize other variables
770771
fsalfirst = zero(rate_prototype)

0 commit comments

Comments
 (0)