@@ -366,3 +366,91 @@ closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens)
366
366
constantify (scens:: AbstractVector{<:Scenario} ) = constantify .(scens)
367
367
cachify (scens:: AbstractVector{<:Scenario} ; use_tuples) = cachify .(scens; use_tuples)
368
368
constantorcachify (scens:: AbstractVector{<:Scenario} ) = constantorcachify .(scens)
369
+
370
+ # # Compute results with backend
371
+
372
+ get_res1 (:: Val , args... ) = nothing
373
+ get_res2 (:: Val , args... ) = nothing
374
+
375
+ function get_res1 (:: Val{:derivative} , f, backend:: AbstractADType , x, contexts... )
376
+ return derivative (f, backend, x, contexts... )
377
+ end
378
+ function get_res1 (:: Val{:derivative} , f!, y, backend:: AbstractADType , x, contexts... )
379
+ return derivative (f!, y, backend, x, contexts... )
380
+ end
381
+ function get_res1 (:: Val{:gradient} , f, backend:: AbstractADType , x, contexts... )
382
+ return gradient (f, backend, x, contexts... )
383
+ end
384
+ function get_res1 (:: Val{:jacobian} , f, backend:: AbstractADType , x, contexts... )
385
+ return jacobian (f, backend, x, contexts... )
386
+ end
387
+ function get_res1 (:: Val{:jacobian} , f!, y, backend:: AbstractADType , x, contexts... )
388
+ return jacobian (f!, y, backend, x, contexts... )
389
+ end
390
+ function get_res1 (:: Val{:second_derivative} , f, backend:: AbstractADType , x, contexts... )
391
+ return derivative (f, backend, x, contexts... )
392
+ end
393
+ function get_res1 (:: Val{:hessian} , f, backend:: AbstractADType , x, contexts... )
394
+ return gradient (f, backend, x, contexts... )
395
+ end
396
+
397
+ function get_res2 (:: Val{:second_derivative} , f, backend:: AbstractADType , x, contexts... )
398
+ return second_derivative (f, backend, x, contexts... )
399
+ end
400
+ function get_res2 (:: Val{:hessian} , f, backend:: AbstractADType , x, contexts... )
401
+ return hessian (f, backend, x, contexts... )
402
+ end
403
+
404
+ function get_res1 (:: Val{:pushforward} , f, backend:: AbstractADType , x, t, contexts... )
405
+ return pushforward (f, backend, x, t, contexts... )
406
+ end
407
+ function get_res1 (:: Val{:pushforward} , f!, y, backend:: AbstractADType , x, t, contexts... )
408
+ return pushforward (f!, y, backend, x, t, contexts... )
409
+ end
410
+ function get_res1 (:: Val{:pullback} , f, backend:: AbstractADType , x, t, contexts... )
411
+ return pullback (f, backend, x, t, contexts... )
412
+ end
413
+ function get_res1 (:: Val{:pullback} , f!, y, backend:: AbstractADType , x, t, contexts... )
414
+ return pullback (f!, y, backend, x, t, contexts... )
415
+ end
416
+ function get_res1 (:: Val{:hvp} , f, backend:: AbstractADType , x, t, contexts... )
417
+ return gradient (f, backend, x, contexts... )
418
+ end
419
+
420
+ function get_res2 (:: Val{:hvp} , f, backend:: AbstractADType , x, t, contexts... )
421
+ return hvp (f, backend, x, t, contexts... )
422
+ end
423
+
424
+ """
425
+ compute_results(scen::Scenario, backend::AbstractADType)
426
+
427
+ Return a scenario identical to `scen` but where the first- and second-order results `res1` and `res2` have been computed with the given differentiation `backend`.
428
+
429
+ Useful for comparison of outputs between backends.
430
+ """
431
+ function compute_results (
432
+ scen:: Scenario{op,pl_op,pl_fun} , backend:: AbstractADType
433
+ ) where {op,pl_op,pl_fun}
434
+ (; f, y, x, t, contexts, prep_args, name) = deepcopy (scen)
435
+ if pl_fun == :in
436
+ if isnothing (t)
437
+ new_res1 = get_res1 (Val (op), f, y, backend, x, contexts... )
438
+ new_res2 = get_res2 (Val (op), f, y, backend, x, contexts... )
439
+ else
440
+ new_res1 = get_res1 (Val (op), f, y, backend, x, t, contexts... )
441
+ new_res2 = get_res2 (Val (op), f, y, backend, x, t, contexts... )
442
+ end
443
+ else
444
+ if isnothing (t)
445
+ new_res1 = get_res1 (Val (op), f, backend, x, contexts... )
446
+ new_res2 = get_res2 (Val (op), f, backend, x, contexts... )
447
+ else
448
+ new_res1 = get_res1 (Val (op), f, backend, x, t, contexts... )
449
+ new_res2 = get_res2 (Val (op), f, backend, x, t, contexts... )
450
+ end
451
+ end
452
+ new_scen = Scenario {op,pl_op,pl_fun} (;
453
+ f, x, y, t, contexts, res1= new_res1, res2= new_res2, prep_args, name
454
+ )
455
+ return new_scen
456
+ end
0 commit comments