Skip to content

Commit 3d5bee0

Browse files
authored
stage2(forward): add method table backedge for non-existing frule method (#182)
1 parent 2462184 commit 3d5bee0

File tree

3 files changed

+112
-19
lines changed

3 files changed

+112
-19
lines changed

Manifest.toml

Lines changed: 83 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/analysis/forward.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,19 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
1919
frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}]
2020
frule_argtypes = append!(frule_preargtypes, arginfo.argtypes)
2121
frule_arginfo = ArgInfo(nothing, frule_argtypes)
22+
frule_si = StmtInfo(true)
23+
frule_atype = CC.argtypes_to_type(frule_argtypes)
2224
# turn off frule analysis in the frule to avoid cycling
2325
interp′ = disable_forward(interp)
24-
frule_call = CC.abstract_call_known(interp′, ChainRulesCore.frule, frule_arginfo, StmtInfo(true), sv, #=max_methods=#-1)
26+
frule_call = CC.abstract_call_gf_by_type(interp′,
27+
ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)
2528
if frule_call.rt !== Const(nothing)
2629
return CallMeta(primal_call.rt, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call))
30+
else
31+
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
2732
end
2833

2934
return nothing
3035
end
36+
37+
const frule_mt = methods(ChainRulesCore.frule).mt

test/stage2_fwd.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,28 @@ module stage2_fwd
1111
end
1212

1313
myminus(a, b) = a - b
14+
self_minus(a) = myminus(a, a)
1415
ChainRulesCore.@scalar_rule myminus(x, y) (true, -1)
16+
let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64})
17+
@test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64})
18+
@test self_minus′(1.0) == 0.
19+
end
20+
ChainRulesCore.@scalar_rule myminus(x, y) (true, true) # frule for `x - y`
21+
let self_minus′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64})
22+
@test isa(self_minus′, Core.OpaqueClosure{Tuple{Float64}, Float64})
23+
@test self_minus′(1.0) == 2.
24+
end
1525

16-
self_minus(a) = myminus(a, a)
17-
let self_minus′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}, 2)
18-
@test isa(self_minus′′, Core.OpaqueClosure{Tuple{Float64}, Float64})
19-
@test self_minus′′(1.0) == 0.
26+
myminus2(a, b) = a - b
27+
self_minus2(a) = myminus2(a, a)
28+
let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64})
29+
@test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64})
30+
@test self_minus2′(1.0) == 0.
31+
end
32+
ChainRulesCore.@scalar_rule myminus2(x, y) (true, true) # frule for `x - y`
33+
let self_minus2′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus2), Float64})
34+
@test isa(self_minus2′, Core.OpaqueClosure{Tuple{Float64}, Float64})
35+
@test self_minus2′(1.0) == 2.
2036
end
2137

2238
@testset "structs" begin
@@ -43,4 +59,4 @@ module stage2_fwd
4359
g(x) = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), Diffractor.TaylorBundle{1}(x, (1.0,)))
4460
Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10f0, (1.0,)))
4561
end
46-
end
62+
end

0 commit comments

Comments
 (0)