Skip to content

Commit 8843031

Browse files
authored
Warm up for Tapir's rrule (#254)
* Warm up for Tapir's rrule * Almost recursive warm up
1 parent 8b35cfd commit 8843031

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/onearg.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@ struct TapirOneArgPullbackExtras{Y,R} <: PullbackExtras
33
rrule::R
44
end
55

6-
function DI.prepare_pullback(f, ::AutoTapir, x, dy)
6+
function DI.prepare_pullback(f, backend::AutoTapir, x, dy)
77
y = f(x)
8-
return TapirOneArgPullbackExtras(y, build_rrule(f, x))
8+
rrule = build_rrule(f, x)
9+
extras = TapirOneArgPullbackExtras(y, rrule)
10+
DI.value_and_pullback(f, backend, x, dy, extras) # warm up
11+
return extras
912
end
1013

1114
function DI.value_and_pullback(

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/twoarg.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ struct TapirTwoArgPullbackExtras{R} <: PullbackExtras
22
rrule::R
33
end
44

5-
function DI.prepare_pullback(f!, y, ::AutoTapir, x, dy)
6-
return TapirTwoArgPullbackExtras(build_rrule(f!, y, x))
5+
function DI.prepare_pullback(f!, y, backend::AutoTapir, x, dy)
6+
rrule = build_rrule(f!, y, x)
7+
extras = TapirTwoArgPullbackExtras(rrule)
8+
DI.value_and_pullback(f!, y, backend, x, dy, extras) # warm up
9+
return extras
710
end
811

912
# see https://github.com/withbayes/Tapir.jl/issues/113#issuecomment-2036718992

0 commit comments

Comments
 (0)