87
87
function DI. value_and_gradient! (
88
88
f, grad, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff{compile} , x
89
89
) where {compile}
90
- y = f (x) # TODO : ReverseDiff#251
91
- result = DiffResult (y, (grad,))
90
+ result = MutableDiffResult (zero (eltype (x)), (grad,)) # ReverseDiff#251
92
91
if compile
93
92
result = gradient! (result, prep. tape, x)
94
93
else
95
94
result = gradient! (result, f, x, prep. config)
96
95
end
97
- y = DR. value (result)
98
- grad === DR. gradient (result) || copyto! (grad, DR. gradient (result))
99
- return y, grad
96
+ return DR. value (result), grad # ReverseDiff#269
100
97
end
101
98
102
99
function DI. value_and_gradient (
103
- f, prep:: ReverseDiffGradientPrep , backend:: AutoReverseDiff , x
104
- )
105
- grad = similar (x)
106
- return DI. value_and_gradient! (f, grad, prep, backend, x)
100
+ f, prep:: ReverseDiffGradientPrep , backend:: AutoReverseDiff{compile} , x
101
+ ) where {compile}
102
+ # GradientResult tries to mutate an SArray
103
+ result = MutableDiffResult (zero (eltype (x)), (similar (x),))
104
+ if compile
105
+ result = gradient! (result, prep. tape, x)
106
+ else
107
+ result = gradient! (result, f, x, prep. config)
108
+ end
109
+ return DR. value (result), DR. gradient (result)
107
110
end
108
111
109
112
function DI. gradient! (
@@ -144,23 +147,19 @@ function DI.value_and_gradient!(
144
147
contexts:: Vararg{DI.Context,C} ,
145
148
) where {C}
146
149
fc = DI. with_contexts (f, contexts... )
147
- y = fc (x) # TODO : ReverseDiff#251
148
- result = DiffResult (y, (grad,))
150
+ result = MutableDiffResult (zero (eltype (x)), (grad,)) # ReverseDiff#251
149
151
result = gradient! (result, fc, x, prep. config)
150
- y = DR. value (result)
151
- grad === DR. gradient (result) || copyto! (grad, DR. gradient (result))
152
- return y, grad
152
+ return DR. value (result), grad # ReverseDiff#269
153
153
end
154
154
155
155
function DI. value_and_gradient (
156
- f,
157
- prep:: ReverseDiffGradientPrep ,
158
- backend:: AutoReverseDiff ,
159
- x,
160
- contexts:: Vararg{DI.Context,C} ,
156
+ f, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x, contexts:: Vararg{DI.Context,C}
161
157
) where {C}
162
- grad = similar (x)
163
- return DI. value_and_gradient! (f, grad, prep, backend, x, contexts... )
158
+ fc = DI. with_contexts (f, contexts... )
159
+ # GradientResult tries to mutate an SArray
160
+ result = MutableDiffResult (zero (eltype (x)), (similar (x),))
161
+ result = gradient! (result, fc, x, prep. config)
162
+ return DR. value (result), DR. gradient (result)
164
163
end
165
164
166
165
function DI. gradient! (
@@ -310,31 +309,23 @@ end
310
309
311
310
# ## Without contexts
312
311
313
- @kwdef struct ReverseDiffHessianPrep{GC ,HC,GT ,HT} <: DI.HessianPrep
314
- gradient_config :: GC
312
+ @kwdef struct ReverseDiffHessianPrep{G <: ReverseDiffGradientPrep ,HC,HT} <: DI.HessianPrep
313
+ gradient_prep :: G
315
314
hessian_config:: HC
316
- gradient_tape:: GT
317
315
hessian_tape:: HT
318
316
end
319
317
320
- function DI. prepare_hessian (f, :: AutoReverseDiff{compile} , x) where {compile}
318
+ function DI. prepare_hessian (f, backend:: AutoReverseDiff{compile} , x) where {compile}
319
+ gradient_prep = DI. prepare_gradient (f, backend, x)
321
320
if compile
322
- gradient_tape = ReverseDiff. compile (GradientTape (f, x))
323
321
hessian_tape = ReverseDiff. compile (HessianTape (f, x))
324
322
return ReverseDiffHessianPrep (;
325
- gradient_config= nothing ,
326
- hessian_config= nothing ,
327
- gradient_tape= gradient_tape,
328
- hessian_tape= hessian_tape,
323
+ gradient_prep, hessian_config= nothing , hessian_tape= hessian_tape
329
324
)
330
325
else
331
- gradient_config = GradientConfig (x)
332
326
hessian_config = HessianConfig (x)
333
327
return ReverseDiffHessianPrep (;
334
- gradient_config= gradient_config,
335
- hessian_config= hessian_config,
336
- gradient_tape= nothing ,
337
- hessian_tape= nothing ,
328
+ gradient_prep, hessian_config= hessian_config, hessian_tape= nothing
338
329
)
339
330
end
340
331
end
@@ -360,47 +351,32 @@ function DI.hessian(
360
351
end
361
352
362
353
function DI. value_gradient_and_hessian! (
363
- f, grad, hess, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff{compile} , x
354
+ f, grad, hess, prep:: ReverseDiffHessianPrep , backend :: AutoReverseDiff{compile} , x
364
355
) where {compile}
365
- y = f (x) # TODO : ReverseDiff#251
366
- result = DiffResult (y, (grad, hess))
367
- if compile
368
- result = hessian! (result, prep. hessian_tape, x)
369
- grad = gradient! (grad, prep. gradient_tape, x) # TODO : ReverseDiff#251
370
- else
371
- result = hessian! (result, f, x) # TODO : add prep.hessian_config
372
- grad = gradient! (grad, f, x, prep. gradient_config) # TODO : ReverseDiff#251
373
- end
374
- # grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
375
- hess === DR. hessian (result) || copyto! (hess, DR. hessian (result))
356
+ y = f (x)
357
+ DI. gradient! (f, grad, prep. gradient_prep, backend, x)
358
+ DI. hessian! (f, hess, prep, backend, x)
376
359
return y, grad, hess
377
360
end
378
361
379
362
function DI. value_gradient_and_hessian (
380
- f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff{compile} , x
363
+ f, prep:: ReverseDiffHessianPrep , backend :: AutoReverseDiff{compile} , x
381
364
) where {compile}
382
- y = f (x) # TODO : remove once ReverseDiff#251 is fixed
383
- result = DiffResult (y, (similar (x), similar (x, length (x), length (x))))
384
- if compile
385
- result = hessian! (result, prep. hessian_tape, x)
386
- else
387
- result = hessian! (result, f, x) # todo: add prep.hessian_config
388
- end
389
- return (y, DR. gradient (result), DR. hessian (result))
365
+ y = f (x)
366
+ grad = DI. gradient (f, prep. gradient_prep, backend, x)
367
+ hess = DI. hessian (f, prep, backend, x)
368
+ return y, grad, hess
390
369
end
391
370
392
371
# ## With contexts
393
372
394
373
function DI. prepare_hessian (
395
- f, :: AutoReverseDiff , x, contexts:: Vararg{DI.Context,C}
374
+ f, backend :: AutoReverseDiff , x, contexts:: Vararg{DI.Context,C}
396
375
) where {C}
397
- gradient_config = GradientConfig (x )
376
+ gradient_prep = DI . prepare_gradient (f, backend, x, contexts ... )
398
377
hessian_config = HessianConfig (x)
399
378
return ReverseDiffHessianPrep (;
400
- gradient_config= gradient_config,
401
- hessian_config= hessian_config,
402
- gradient_tape= nothing ,
403
- hessian_tape= nothing ,
379
+ gradient_prep, hessian_config= hessian_config, hessian_tape= nothing
404
380
)
405
381
end
406
382
@@ -428,27 +404,25 @@ function DI.value_gradient_and_hessian!(
428
404
grad,
429
405
hess,
430
406
prep:: ReverseDiffHessianPrep ,
431
- :: AutoReverseDiff ,
407
+ backend :: AutoReverseDiff ,
432
408
x,
433
409
contexts:: Vararg{DI.Context,C} ,
434
410
) where {C}
435
- fc = DI. with_contexts (f, contexts... )
436
- y = fc (x) # TODO : ReverseDiff#251
437
- result = DiffResult (y, (grad, hess))
438
- result = hessian! (result, fc, x) # TODO : add prep.hessian_config
439
- y = DR. value (result)
440
- # grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
441
- grad = gradient! (grad, fc, x, prep. gradient_config) # TODO : ReverseDiff#251
442
- hess === DR. hessian (result) || copyto! (hess, DR. hessian (result))
411
+ y = f (x, map (DI. unwrap, contexts)... )
412
+ DI. gradient! (f, grad, prep. gradient_prep, backend, x, contexts... )
413
+ DI. hessian! (f, hess, prep, backend, x, contexts... )
443
414
return y, grad, hess
444
415
end
445
416
446
417
function DI. value_gradient_and_hessian (
447
- f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x, contexts:: Vararg{DI.Context,C}
418
+ f,
419
+ prep:: ReverseDiffHessianPrep ,
420
+ backend:: AutoReverseDiff ,
421
+ x,
422
+ contexts:: Vararg{DI.Context,C} ,
448
423
) where {C}
449
- fc = DI. with_contexts (f, contexts... )
450
- y = fc (x) # TODO : ReverseDiff#251
451
- result = HessianResult (x)
452
- result = hessian! (result, fc, x) # TODO : add prep.hessian_config
453
- return (DR. value (result), DR. gradient (result), DR. hessian (result))
424
+ y = f (x, map (DI. unwrap, contexts)... )
425
+ grad = DI. gradient (f, prep. gradient_prep, backend, x, contexts... )
426
+ hess = DI. hessian (f, prep, backend, x, contexts... )
427
+ return y, grad, hess
454
428
end
0 commit comments