@@ -313,7 +313,7 @@ function DI.prepare_gradient(
313
313
) where {F,C}
314
314
fc = DI. with_contexts (f, contexts... )
315
315
chunk = choose_chunk (backend, x)
316
- tag = get_tag (fc , backend, x)
316
+ tag = get_tag (f , backend, x)
317
317
config = GradientConfig (fc, x, chunk, tag)
318
318
return ForwardDiffGradientPrep (config)
319
319
end
@@ -329,7 +329,10 @@ function DI.value_and_gradient!(
329
329
fc = DI. with_contexts (f, contexts... )
330
330
result = DiffResult (zero (eltype (x)), (grad,))
331
331
CHK = tag_type (backend) === Nothing
332
- result = gradient! (result, fc, x, prep. config, Val (CHK))
332
+ if CHK
333
+ checktag (prep. config, f, x)
334
+ end
335
+ result = gradient! (result, fc, x, prep. config, Val (false ))
333
336
y = DR. value (result)
334
337
grad === DR. gradient (result) || copyto! (grad, DR. gradient (result))
335
338
return y, grad
@@ -345,7 +348,10 @@ function DI.value_and_gradient(
345
348
fc = DI. with_contexts (f, contexts... )
346
349
result = GradientResult (x)
347
350
CHK = tag_type (backend) === Nothing
348
- result = gradient! (result, fc, x, prep. config, Val (CHK))
351
+ if CHK
352
+ checktag (prep. config, f, x)
353
+ end
354
+ result = gradient! (result, fc, x, prep. config, Val (false ))
349
355
return DR. value (result), DR. gradient (result)
350
356
end
351
357
@@ -359,7 +365,10 @@ function DI.gradient!(
359
365
) where {F,C}
360
366
fc = DI. with_contexts (f, contexts... )
361
367
CHK = tag_type (backend) === Nothing
362
- return gradient! (grad, fc, x, prep. config, Val (CHK))
368
+ if CHK
369
+ checktag (prep. config, f, x)
370
+ end
371
+ return gradient! (grad, fc, x, prep. config, Val (false ))
363
372
end
364
373
365
374
function DI. gradient (
@@ -371,7 +380,10 @@ function DI.gradient(
371
380
) where {F,C}
372
381
fc = DI. with_contexts (f, contexts... )
373
382
CHK = tag_type (backend) === Nothing
374
- return gradient (fc, x, prep. config, Val (CHK))
383
+ if CHK
384
+ checktag (prep. config, f, x)
385
+ end
386
+ return gradient (fc, x, prep. config, Val (false ))
375
387
end
376
388
377
389
# # Jacobian
@@ -456,7 +468,7 @@ function DI.prepare_jacobian(
456
468
) where {F,C}
457
469
fc = DI. with_contexts (f, contexts... )
458
470
chunk = choose_chunk (backend, x)
459
- tag = get_tag (fc , backend, x)
471
+ tag = get_tag (f , backend, x)
460
472
config = JacobianConfig (fc, x, chunk, tag)
461
473
return ForwardDiffOneArgJacobianPrep (config)
462
474
end
@@ -473,7 +485,10 @@ function DI.value_and_jacobian!(
473
485
y = fc (x)
474
486
result = DiffResult (y, (jac,))
475
487
CHK = tag_type (backend) === Nothing
476
- result = jacobian! (result, fc, x, prep. config, Val (CHK))
488
+ if CHK
489
+ checktag (prep. config, f, x)
490
+ end
491
+ result = jacobian! (result, fc, x, prep. config, Val (false ))
477
492
y = DR. value (result)
478
493
jac === DR. jacobian (result) || copyto! (jac, DR. jacobian (result))
479
494
return y, jac
@@ -488,7 +503,10 @@ function DI.value_and_jacobian(
488
503
) where {F,C}
489
504
fc = DI. with_contexts (f, contexts... )
490
505
CHK = tag_type (backend) === Nothing
491
- return fc (x), jacobian (fc, x, prep. config, Val (CHK))
506
+ if CHK
507
+ checktag (prep. config, f, x)
508
+ end
509
+ return fc (x), jacobian (fc, x, prep. config, Val (false ))
492
510
end
493
511
494
512
function DI. jacobian! (
@@ -501,7 +519,10 @@ function DI.jacobian!(
501
519
) where {F,C}
502
520
fc = DI. with_contexts (f, contexts... )
503
521
CHK = tag_type (backend) === Nothing
504
- return jacobian! (jac, fc, x, prep. config, Val (CHK))
522
+ if CHK
523
+ checktag (prep. config, f, x)
524
+ end
525
+ return jacobian! (jac, fc, x, prep. config, Val (false ))
505
526
end
506
527
507
528
function DI. jacobian (
@@ -513,7 +534,10 @@ function DI.jacobian(
513
534
) where {F,C}
514
535
fc = DI. with_contexts (f, contexts... )
515
536
CHK = tag_type (backend) === Nothing
516
- return jacobian (fc, x, prep. config, Val (CHK))
537
+ if CHK
538
+ checktag (prep. config, f, x)
539
+ end
540
+ return jacobian (fc, x, prep. config, Val (false ))
517
541
end
518
542
519
543
# # Second derivative
@@ -738,7 +762,7 @@ function DI.prepare_hessian(
738
762
) where {F,C}
739
763
fc = DI. with_contexts (f, contexts... )
740
764
chunk = choose_chunk (backend, x)
741
- tag = get_tag (fc , backend, x)
765
+ tag = get_tag (f , backend, x)
742
766
result = HessianResult (x)
743
767
array_config = HessianConfig (fc, x, chunk, tag)
744
768
result_config = HessianConfig (fc, result, x, chunk, tag)
@@ -755,7 +779,10 @@ function DI.hessian!(
755
779
) where {F,C}
756
780
fc = DI. with_contexts (f, contexts... )
757
781
CHK = tag_type (backend) === Nothing
758
- return hessian! (hess, fc, x, prep. array_config, Val (CHK))
782
+ if CHK
783
+ checktag (prep. array_config, f, x)
784
+ end
785
+ return hessian! (hess, fc, x, prep. array_config, Val (false ))
759
786
end
760
787
761
788
function DI. hessian (
@@ -767,7 +794,10 @@ function DI.hessian(
767
794
) where {F,C}
768
795
fc = DI. with_contexts (f, contexts... )
769
796
CHK = tag_type (backend) === Nothing
770
- return hessian (fc, x, prep. array_config, Val (CHK))
797
+ if CHK
798
+ checktag (prep. array_config, f, x)
799
+ end
800
+ return hessian (fc, x, prep. array_config, Val (false ))
771
801
end
772
802
773
803
function DI. value_gradient_and_hessian! (
@@ -782,7 +812,10 @@ function DI.value_gradient_and_hessian!(
782
812
fc = DI. with_contexts (f, contexts... )
783
813
result = DiffResult (one (eltype (x)), (grad, hess))
784
814
CHK = tag_type (backend) === Nothing
785
- result = hessian! (result, fc, x, prep. result_config, Val (CHK))
815
+ if CHK
816
+ checktag (prep. result_config, f, x)
817
+ end
818
+ result = hessian! (result, fc, x, prep. result_config, Val (false ))
786
819
y = DR. value (result)
787
820
grad === DR. gradient (result) || copyto! (grad, DR. gradient (result))
788
821
hess === DR. hessian (result) || copyto! (hess, DR. hessian (result))
@@ -799,6 +832,9 @@ function DI.value_gradient_and_hessian(
799
832
fc = DI. with_contexts (f, contexts... )
800
833
result = HessianResult (x)
801
834
CHK = tag_type (backend) === Nothing
802
- result = hessian! (result, fc, x, prep. result_config, Val (CHK))
835
+ if CHK
836
+ checktag (prep. result_config, f, x)
837
+ end
838
+ result = hessian! (result, fc, x, prep. result_config, Val (false ))
803
839
return (DR. value (result), DR. gradient (result), DR. hessian (result))
804
840
end
0 commit comments