@@ -366,3 +366,91 @@ closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens)
366366constantify (scens:: AbstractVector{<:Scenario} ) = constantify .(scens)
367367cachify (scens:: AbstractVector{<:Scenario} ; use_tuples) = cachify .(scens; use_tuples)
368368constantorcachify (scens:: AbstractVector{<:Scenario} ) = constantorcachify .(scens)
369+
370+ # # Compute results with backend
371+
372+ get_res1 (:: Val , args... ) = nothing
373+ get_res2 (:: Val , args... ) = nothing
374+
375+ function get_res1 (:: Val{:derivative} , f, backend:: AbstractADType , x, contexts... )
376+ return derivative (f, backend, x, contexts... )
377+ end
378+ function get_res1 (:: Val{:derivative} , f!, y, backend:: AbstractADType , x, contexts... )
379+ return derivative (f!, y, backend, x, contexts... )
380+ end
381+ function get_res1 (:: Val{:gradient} , f, backend:: AbstractADType , x, contexts... )
382+ return gradient (f, backend, x, contexts... )
383+ end
384+ function get_res1 (:: Val{:jacobian} , f, backend:: AbstractADType , x, contexts... )
385+ return jacobian (f, backend, x, contexts... )
386+ end
387+ function get_res1 (:: Val{:jacobian} , f!, y, backend:: AbstractADType , x, contexts... )
388+ return jacobian (f!, y, backend, x, contexts... )
389+ end
390+ function get_res1 (:: Val{:second_derivative} , f, backend:: AbstractADType , x, contexts... )
391+ return derivative (f, backend, x, contexts... )
392+ end
393+ function get_res1 (:: Val{:hessian} , f, backend:: AbstractADType , x, contexts... )
394+ return gradient (f, backend, x, contexts... )
395+ end
396+
397+ function get_res2 (:: Val{:second_derivative} , f, backend:: AbstractADType , x, contexts... )
398+ return second_derivative (f, backend, x, contexts... )
399+ end
400+ function get_res2 (:: Val{:hessian} , f, backend:: AbstractADType , x, contexts... )
401+ return hessian (f, backend, x, contexts... )
402+ end
403+
404+ function get_res1 (:: Val{:pushforward} , f, backend:: AbstractADType , x, t, contexts... )
405+ return pushforward (f, backend, x, t, contexts... )
406+ end
407+ function get_res1 (:: Val{:pushforward} , f!, y, backend:: AbstractADType , x, t, contexts... )
408+ return pushforward (f!, y, backend, x, t, contexts... )
409+ end
410+ function get_res1 (:: Val{:pullback} , f, backend:: AbstractADType , x, t, contexts... )
411+ return pullback (f, backend, x, t, contexts... )
412+ end
413+ function get_res1 (:: Val{:pullback} , f!, y, backend:: AbstractADType , x, t, contexts... )
414+ return pullback (f!, y, backend, x, t, contexts... )
415+ end
416+ function get_res1 (:: Val{:hvp} , f, backend:: AbstractADType , x, t, contexts... )
417+ return gradient (f, backend, x, contexts... )
418+ end
419+
420+ function get_res2 (:: Val{:hvp} , f, backend:: AbstractADType , x, t, contexts... )
421+ return hvp (f, backend, x, t, contexts... )
422+ end
423+
424+ """
425+ compute_results(scen::Scenario, backend::AbstractADType)
426+
427+ Return a scenario identical to `scen` but where the first- and second-order results `res1` and `res2` have been computed with the given differentiation `backend`.
428+
429+ Useful for comparison of outputs between backends.
430+ """
431+ function compute_results (
432+ scen:: Scenario{op,pl_op,pl_fun} , backend:: AbstractADType
433+ ) where {op,pl_op,pl_fun}
434+ (; f, y, x, t, contexts, prep_args, name) = deepcopy (scen)
435+ if pl_fun == :in
436+ if isnothing (t)
437+ new_res1 = get_res1 (Val (op), f, y, backend, x, contexts... )
438+ new_res2 = get_res2 (Val (op), f, y, backend, x, contexts... )
439+ else
440+ new_res1 = get_res1 (Val (op), f, y, backend, x, t, contexts... )
441+ new_res2 = get_res2 (Val (op), f, y, backend, x, t, contexts... )
442+ end
443+ else
444+ if isnothing (t)
445+ new_res1 = get_res1 (Val (op), f, backend, x, contexts... )
446+ new_res2 = get_res2 (Val (op), f, backend, x, contexts... )
447+ else
448+ new_res1 = get_res1 (Val (op), f, backend, x, t, contexts... )
449+ new_res2 = get_res2 (Val (op), f, backend, x, t, contexts... )
450+ end
451+ end
452+ new_scen = Scenario {op,pl_op,pl_fun} (;
453+ f, x, y, t, contexts, res1= new_res1, res2= new_res2, prep_args, name
454+ )
455+ return new_scen
456+ end
0 commit comments