Skip to content

Commit ff95f72

Browse files
authored
fix: check ForwardDiff tag manually to exclude contexts from it (#740)
1 parent 864f215 commit ff95f72

File tree

3 files changed

+86
-25
lines changed

3 files changed

+86
-25
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using ForwardDiff:
1313
HessianConfig,
1414
JacobianConfig,
1515
Tag,
16+
checktag,
1617
derivative,
1718
derivative!,
1819
extract_derivative,

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ function DI.prepare_gradient(
313313
) where {F,C}
314314
fc = DI.with_contexts(f, contexts...)
315315
chunk = choose_chunk(backend, x)
316-
tag = get_tag(fc, backend, x)
316+
tag = get_tag(f, backend, x)
317317
config = GradientConfig(fc, x, chunk, tag)
318318
return ForwardDiffGradientPrep(config)
319319
end
@@ -329,7 +329,10 @@ function DI.value_and_gradient!(
329329
fc = DI.with_contexts(f, contexts...)
330330
result = DiffResult(zero(eltype(x)), (grad,))
331331
CHK = tag_type(backend) === Nothing
332-
result = gradient!(result, fc, x, prep.config, Val(CHK))
332+
if CHK
333+
checktag(prep.config, f, x)
334+
end
335+
result = gradient!(result, fc, x, prep.config, Val(false))
333336
y = DR.value(result)
334337
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
335338
return y, grad
@@ -345,7 +348,10 @@ function DI.value_and_gradient(
345348
fc = DI.with_contexts(f, contexts...)
346349
result = GradientResult(x)
347350
CHK = tag_type(backend) === Nothing
348-
result = gradient!(result, fc, x, prep.config, Val(CHK))
351+
if CHK
352+
checktag(prep.config, f, x)
353+
end
354+
result = gradient!(result, fc, x, prep.config, Val(false))
349355
return DR.value(result), DR.gradient(result)
350356
end
351357

@@ -359,7 +365,10 @@ function DI.gradient!(
359365
) where {F,C}
360366
fc = DI.with_contexts(f, contexts...)
361367
CHK = tag_type(backend) === Nothing
362-
return gradient!(grad, fc, x, prep.config, Val(CHK))
368+
if CHK
369+
checktag(prep.config, f, x)
370+
end
371+
return gradient!(grad, fc, x, prep.config, Val(false))
363372
end
364373

365374
function DI.gradient(
@@ -371,7 +380,10 @@ function DI.gradient(
371380
) where {F,C}
372381
fc = DI.with_contexts(f, contexts...)
373382
CHK = tag_type(backend) === Nothing
374-
return gradient(fc, x, prep.config, Val(CHK))
383+
if CHK
384+
checktag(prep.config, f, x)
385+
end
386+
return gradient(fc, x, prep.config, Val(false))
375387
end
376388

377389
## Jacobian
@@ -456,7 +468,7 @@ function DI.prepare_jacobian(
456468
) where {F,C}
457469
fc = DI.with_contexts(f, contexts...)
458470
chunk = choose_chunk(backend, x)
459-
tag = get_tag(fc, backend, x)
471+
tag = get_tag(f, backend, x)
460472
config = JacobianConfig(fc, x, chunk, tag)
461473
return ForwardDiffOneArgJacobianPrep(config)
462474
end
@@ -473,7 +485,10 @@ function DI.value_and_jacobian!(
473485
y = fc(x)
474486
result = DiffResult(y, (jac,))
475487
CHK = tag_type(backend) === Nothing
476-
result = jacobian!(result, fc, x, prep.config, Val(CHK))
488+
if CHK
489+
checktag(prep.config, f, x)
490+
end
491+
result = jacobian!(result, fc, x, prep.config, Val(false))
477492
y = DR.value(result)
478493
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
479494
return y, jac
@@ -488,7 +503,10 @@ function DI.value_and_jacobian(
488503
) where {F,C}
489504
fc = DI.with_contexts(f, contexts...)
490505
CHK = tag_type(backend) === Nothing
491-
return fc(x), jacobian(fc, x, prep.config, Val(CHK))
506+
if CHK
507+
checktag(prep.config, f, x)
508+
end
509+
return fc(x), jacobian(fc, x, prep.config, Val(false))
492510
end
493511

494512
function DI.jacobian!(
@@ -501,7 +519,10 @@ function DI.jacobian!(
501519
) where {F,C}
502520
fc = DI.with_contexts(f, contexts...)
503521
CHK = tag_type(backend) === Nothing
504-
return jacobian!(jac, fc, x, prep.config, Val(CHK))
522+
if CHK
523+
checktag(prep.config, f, x)
524+
end
525+
return jacobian!(jac, fc, x, prep.config, Val(false))
505526
end
506527

507528
function DI.jacobian(
@@ -513,7 +534,10 @@ function DI.jacobian(
513534
) where {F,C}
514535
fc = DI.with_contexts(f, contexts...)
515536
CHK = tag_type(backend) === Nothing
516-
return jacobian(fc, x, prep.config, Val(CHK))
537+
if CHK
538+
checktag(prep.config, f, x)
539+
end
540+
return jacobian(fc, x, prep.config, Val(false))
517541
end
518542

519543
## Second derivative
@@ -738,7 +762,7 @@ function DI.prepare_hessian(
738762
) where {F,C}
739763
fc = DI.with_contexts(f, contexts...)
740764
chunk = choose_chunk(backend, x)
741-
tag = get_tag(fc, backend, x)
765+
tag = get_tag(f, backend, x)
742766
result = HessianResult(x)
743767
array_config = HessianConfig(fc, x, chunk, tag)
744768
result_config = HessianConfig(fc, result, x, chunk, tag)
@@ -755,7 +779,10 @@ function DI.hessian!(
755779
) where {F,C}
756780
fc = DI.with_contexts(f, contexts...)
757781
CHK = tag_type(backend) === Nothing
758-
return hessian!(hess, fc, x, prep.array_config, Val(CHK))
782+
if CHK
783+
checktag(prep.array_config, f, x)
784+
end
785+
return hessian!(hess, fc, x, prep.array_config, Val(false))
759786
end
760787

761788
function DI.hessian(
@@ -767,7 +794,10 @@ function DI.hessian(
767794
) where {F,C}
768795
fc = DI.with_contexts(f, contexts...)
769796
CHK = tag_type(backend) === Nothing
770-
return hessian(fc, x, prep.array_config, Val(CHK))
797+
if CHK
798+
checktag(prep.array_config, f, x)
799+
end
800+
return hessian(fc, x, prep.array_config, Val(false))
771801
end
772802

773803
function DI.value_gradient_and_hessian!(
@@ -782,7 +812,10 @@ function DI.value_gradient_and_hessian!(
782812
fc = DI.with_contexts(f, contexts...)
783813
result = DiffResult(one(eltype(x)), (grad, hess))
784814
CHK = tag_type(backend) === Nothing
785-
result = hessian!(result, fc, x, prep.result_config, Val(CHK))
815+
if CHK
816+
checktag(prep.result_config, f, x)
817+
end
818+
result = hessian!(result, fc, x, prep.result_config, Val(false))
786819
y = DR.value(result)
787820
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
788821
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
@@ -799,6 +832,9 @@ function DI.value_gradient_and_hessian(
799832
fc = DI.with_contexts(f, contexts...)
800833
result = HessianResult(x)
801834
CHK = tag_type(backend) === Nothing
802-
result = hessian!(result, fc, x, prep.result_config, Val(CHK))
835+
if CHK
836+
checktag(prep.result_config, f, x)
837+
end
838+
result = hessian!(result, fc, x, prep.result_config, Val(false))
803839
return (DR.value(result), DR.gradient(result), DR.hessian(result))
804840
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ function DI.prepare_derivative(
194194
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
195195
) where {F,C}
196196
fc! = DI.with_contexts(f!, contexts...)
197-
tag = get_tag(fc!, backend, x)
197+
tag = get_tag(f!, backend, x)
198198
config = DerivativeConfig(fc!, y, x, tag)
199199
return ForwardDiffTwoArgDerivativePrep(config)
200200
end
@@ -227,7 +227,10 @@ function DI.value_and_derivative(
227227
fc! = DI.with_contexts(f!, contexts...)
228228
result = MutableDiffResult(y, (similar(y),))
229229
CHK = tag_type(backend) === Nothing
230-
result = derivative!(result, fc!, y, x, prep.config, Val(CHK))
230+
if CHK
231+
checktag(prep.config, f!, x)
232+
end
233+
result = derivative!(result, fc!, y, x, prep.config, Val(false))
231234
return DiffResults.value(result), DiffResults.derivative(result)
232235
end
233236

@@ -243,7 +246,10 @@ function DI.value_and_derivative!(
243246
fc! = DI.with_contexts(f!, contexts...)
244247
result = MutableDiffResult(y, (der,))
245248
CHK = tag_type(backend) === Nothing
246-
result = derivative!(result, fc!, y, x, prep.config, Val(CHK))
249+
if CHK
250+
checktag(prep.config, f!, x)
251+
end
252+
result = derivative!(result, fc!, y, x, prep.config, Val(false))
247253
return DiffResults.value(result), DiffResults.derivative(result)
248254
end
249255

@@ -257,7 +263,10 @@ function DI.derivative(
257263
) where {F,C}
258264
fc! = DI.with_contexts(f!, contexts...)
259265
CHK = tag_type(backend) === Nothing
260-
return derivative(fc!, y, x, prep.config, Val(CHK))
266+
if CHK
267+
checktag(prep.config, f!, x)
268+
end
269+
return derivative(fc!, y, x, prep.config, Val(false))
261270
end
262271

263272
function DI.derivative!(
@@ -271,7 +280,10 @@ function DI.derivative!(
271280
) where {F,C}
272281
fc! = DI.with_contexts(f!, contexts...)
273282
CHK = tag_type(backend) === Nothing
274-
return derivative!(der, fc!, y, x, prep.config, Val(CHK))
283+
if CHK
284+
checktag(prep.config, f!, x)
285+
end
286+
return derivative!(der, fc!, y, x, prep.config, Val(false))
275287
end
276288

277289
## Jacobian
@@ -364,7 +376,7 @@ function DI.prepare_jacobian(
364376
) where {F,C}
365377
fc! = DI.with_contexts(f!, contexts...)
366378
chunk = choose_chunk(backend, x)
367-
tag = get_tag(fc!, backend, x)
379+
tag = get_tag(f!, backend, x)
368380
config = JacobianConfig(fc!, y, x, chunk, tag)
369381
return ForwardDiffTwoArgJacobianPrep(config)
370382
end
@@ -400,7 +412,10 @@ function DI.value_and_jacobian(
400412
jac = similar(y, length(y), length(x))
401413
result = MutableDiffResult(y, (jac,))
402414
CHK = tag_type(backend) === Nothing
403-
result = jacobian!(result, fc!, y, x, prep.config, Val(CHK))
415+
if CHK
416+
checktag(prep.config, f!, x)
417+
end
418+
result = jacobian!(result, fc!, y, x, prep.config, Val(false))
404419
return DiffResults.value(result), DiffResults.jacobian(result)
405420
end
406421

@@ -416,7 +431,10 @@ function DI.value_and_jacobian!(
416431
fc! = DI.with_contexts(f!, contexts...)
417432
result = MutableDiffResult(y, (jac,))
418433
CHK = tag_type(backend) === Nothing
419-
result = jacobian!(result, fc!, y, x, prep.config, Val(CHK))
434+
if CHK
435+
checktag(prep.config, f!, x)
436+
end
437+
result = jacobian!(result, fc!, y, x, prep.config, Val(false))
420438
return DiffResults.value(result), DiffResults.jacobian(result)
421439
end
422440

@@ -430,7 +448,10 @@ function DI.jacobian(
430448
) where {F,C}
431449
fc! = DI.with_contexts(f!, contexts...)
432450
CHK = tag_type(backend) === Nothing
433-
return jacobian(fc!, y, x, prep.config, Val(CHK))
451+
if CHK
452+
checktag(prep.config, f!, x)
453+
end
454+
return jacobian(fc!, y, x, prep.config, Val(false))
434455
end
435456

436457
function DI.jacobian!(
@@ -444,5 +465,8 @@ function DI.jacobian!(
444465
) where {F,C}
445466
fc! = DI.with_contexts(f!, contexts...)
446467
CHK = tag_type(backend) === Nothing
447-
return jacobian!(jac, fc!, y, x, prep.config, Val(CHK))
468+
if CHK
469+
checktag(prep.config, f!, x)
470+
end
471+
return jacobian!(jac, fc!, y, x, prep.config, Val(false))
448472
end

0 commit comments

Comments
 (0)