Skip to content

Commit ea2744d

Browse files
committed
bipass all Diffractor machinery if there is no partials
1 parent da2c0bb commit ea2744d

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

src/stage1/forward.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,18 @@ function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N}
137137
∂☆p(ZeroBundle{N-1}(frule), #= ZeroBundle{N-1}(DiffractorRuleConfig()), =# tupargs, map(primal, downargs)...)
138138
end
139139

140+
# Special shortcut case if there is no derivative information at all:
141+
function (::∂☆internal{N})(f::AbstractZeroBundle{N}, args::AbstractZeroBundle{N}...) where {N}
142+
f_v = primal(f)
143+
args_v = map(primal, args)
144+
return ZeroBundle{N}(f_v(args_v...))
145+
end
146+
function (::∂☆internal{1})(f::AbstractZeroBundle{1}, args::AbstractZeroBundle{1}...)
147+
f_v = primal(f)
148+
args_v = map(primal, args)
149+
return ZeroBundle{1}(f_v(args_v...))
150+
end
151+
140152
function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}
141153
r = ∂☆shuffle{N}()(args...)
142154
if primal(r) === nothing
@@ -147,6 +159,7 @@ function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}
147159
end
148160
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)
149161

162+
150163
# Special case rules for performance
151164
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
152165
s = primal(s)

test/forward.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module forward_tests
22
using Diffractor
3-
using Diffractor: var"'", ∂⃖, DiffractorRuleConfig
3+
using Diffractor: var"'", ∂⃖, DiffractorRuleConfig, ZeroBundle
44
using ChainRules
55
using ChainRulesCore
66
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
@@ -50,5 +50,23 @@ let var"'" = Diffractor.PrimeDerivativeFwd
5050
end
5151

5252

53+
@testset "No partials" begin
54+
primal_calls = Ref(0)
55+
function foo(x, y)
56+
primal_calls[]+=1
57+
return x+y
58+
end
59+
60+
frule_calls = Ref(0)
61+
function ChainRulesCore.frule((_, ẋ, ẏ), ::typeof(foo), x, y)
62+
frule_calls[]+=1
63+
return x+y, ẋ+
64+
end
65+
66+
# Special case if there is no derivative information at all:
67+
@test (Diffractor.∂☆{1}())(ZeroBundle{1}(foo), ZeroBundle{1}(2.0), ZeroBundle{1}(3.0)) == ZeroBundle{1}(5.0)
68+
@test frule_calls[] == 0
69+
@test primal_calls[] == 1
70+
end
5371

5472
end

0 commit comments

Comments
 (0)