Skip to content

Commit c6c0c04

Browse files
Revert "removed generated function"
This reverts commit c2e7131.
1 parent c2e7131 commit c6c0c04

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

ext/DiffEqBaseForwardDiffExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ if !hasmethod(nextfloat, Tuple{ForwardDiff.Dual})
215215
end
216216
end
217217

218+
# bisection(f, tup::Tuple{T,T}, t_forward::Bool) where {T<:ForwardDiff.Dual} = find_zero(f, tup, Roots.AlefeldPotraShi())
219+
218220
# Static Arrays don't support the `init` keyword argument for `sum`
219221
@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...)
220222
@inline function __sum(

src/utils.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -190,22 +190,23 @@ upconversion is not done automatically, the user is required to upconvert all in
190190
themselves, for an example of how this can be confusing to a user see
191191
https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-forced-differential-equation/82937
192192
"""
193-
function anyeltypedual(x, ::Type{Val{counter}} = Val{0}) where {counter}
194-
if isdualtype(typeof(x))
195-
x
193+
@generated function anyeltypedual(x, ::Type{Val{counter}} = Val{0}) where {counter}
194+
x = x.name === Core.Compiler.typename(Type) ? x.parameters[1] : x
195+
if isdualtype(x)
196+
:($x)
196197
elseif fieldnames(x) === ()
197-
Any
198+
:(Any)
198199
elseif counter < DUALCHECK_RECURSION_MAX
199200
T = diffeqmapreduce(x -> anyeltypedual(x, Val{counter + 1}), promote_dual,
200201
x.parameters)
201202
if T === Any || isconcretetype(T)
202-
T
203+
:($T)
203204
else
204-
diffeqmapreduce(DualEltypeChecker(typeof(x), counter + 1), promote_dual,
205-
map(Val, fieldnames((typeof(x)))))
205+
:(diffeqmapreduce(DualEltypeChecker($x, $counter + 1), promote_dual,
206+
map(Val, fieldnames($(typeof(x))))))
206207
end
207208
else
208-
Any
209+
:(Any)
209210
end
210211
end
211212

0 commit comments

Comments
 (0)