Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit f546c7c

Browse files
Remove autmatic FoR soadtype creations
1 parent 1cb8a90 commit f546c7c

File tree

1 file changed

+11
-57
lines changed

1 file changed

+11
-57
lines changed

src/adtypes.jl

Lines changed: 11 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -220,12 +220,14 @@ Hessian is not defined via Zygote.
220220
AutoZygote
221221

222222
function generate_adtype(adtype)
223-
if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode
224-
soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote?
225-
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
226-
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
227-
else
223+
if !(adtype isa SciMLBase.NoAD && adtype isa DifferentiationInterface.SecondOrder)
224+
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
225+
elseif adtype isa DifferentiationInterface.SecondOrder
226+
soadtype = adtype
227+
adtype = adtype.inner
228+
elseif adtype isa SciMLBase.NoAD
228229
soadtype = adtype
230+
adtype = adtype
229231
end
230232
return adtype, soadtype
231233
end
@@ -235,86 +237,38 @@ function generate_sparse_adtype(adtype)
235237
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
236238
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
237239
coloring_algorithm = GreedyColoringAlgorithm())
238-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
240+
if !(adtype.dense_ad isa SciMLBase.NoAD && adtype.dense_ad isa DifferentiationInterface.SecondOrder)
239241
soadtype = AutoSparse(
240242
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
241243
sparsity_detector = TracerSparsityDetector(),
242244
coloring_algorithm = GreedyColoringAlgorithm())
243-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
244-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
245-
soadtype = AutoSparse(
246-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
247-
sparsity_detector = TracerSparsityDetector(),
248-
coloring_algorithm = GreedyColoringAlgorithm()) #make zygote?
249-
elseif !(adtype isa SciMLBase.NoAD) &&
250-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
251-
soadtype = AutoSparse(
252-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
253-
sparsity_detector = TracerSparsityDetector(),
254-
coloring_algorithm = GreedyColoringAlgorithm())
255245
end
256246
elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
257247
!(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm)
258248
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
259249
coloring_algorithm = adtype.coloring_algorithm)
260-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
250+
if !(adtype.dense_ad isa SciMLBase.NoAD && adtype.dense_ad isa DifferentiationInterface.SecondOrder)
261251
soadtype = AutoSparse(
262252
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
263253
sparsity_detector = TracerSparsityDetector(),
264254
coloring_algorithm = adtype.coloring_algorithm)
265-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
266-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
267-
soadtype = AutoSparse(
268-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
269-
sparsity_detector = TracerSparsityDetector(),
270-
coloring_algorithm = adtype.coloring_algorithm)
271-
elseif !(adtype isa SciMLBase.NoAD) &&
272-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
273-
soadtype = AutoSparse(
274-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
275-
sparsity_detector = TracerSparsityDetector(),
276-
coloring_algorithm = adtype.coloring_algorithm)
277255
end
278256
elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) &&
279257
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
280258
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector,
281259
coloring_algorithm = GreedyColoringAlgorithm())
282-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
260+
if !(adtype.dense_ad isa SciMLBase.NoAD && adtype.dense_ad isa DifferentiationInterface.SecondOrder)
283261
soadtype = AutoSparse(
284262
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
285263
sparsity_detector = adtype.sparsity_detector,
286264
coloring_algorithm = GreedyColoringAlgorithm())
287-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
288-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
289-
soadtype = AutoSparse(
290-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
291-
sparsity_detector = adtype.sparsity_detector,
292-
coloring_algorithm = GreedyColoringAlgorithm())
293-
elseif !(adtype isa SciMLBase.NoAD) &&
294-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
295-
soadtype = AutoSparse(
296-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
297-
sparsity_detector = adtype.sparsity_detector,
298-
coloring_algorithm = GreedyColoringAlgorithm())
299265
end
300266
else
301-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
267+
if !(adtype.dense_ad isa SciMLBase.NoAD && adtype.dense_ad isa DifferentiationInterface.SecondOrder)
302268
soadtype = AutoSparse(
303269
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
304270
sparsity_detector = adtype.sparsity_detector,
305271
coloring_algorithm = adtype.coloring_algorithm)
306-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
307-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
308-
soadtype = AutoSparse(
309-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
310-
sparsity_detector = adtype.sparsity_detector,
311-
coloring_algorithm = adtype.coloring_algorithm)
312-
elseif !(adtype isa SciMLBase.NoAD) &&
313-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
314-
soadtype = AutoSparse(
315-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
316-
sparsity_detector = adtype.sparsity_detector,
317-
coloring_algorithm = adtype.coloring_algorithm)
318272
end
319273
end
320274
return adtype, soadtype

0 commit comments

Comments
 (0)