Skip to content

Commit 84ea690

Browse files
authored
Merge pull request #127 from JuliaDiff/ox/comb_nightlyfix
Combined nightly fixes
2 parents 3a13dab + 496d032 commit 84ea690

File tree

9 files changed

+121
-115
lines changed

9 files changed

+121
-115
lines changed

src/Diffractor.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ export ∂⃖, gradient
66

77
const CC = Core.Compiler
88

9+
const GENERATORS = Expr[]
10+
911
include("runtime.jl")
1012
include("interface.jl")
1113
include("utils.jl")
@@ -37,4 +39,11 @@ include("debugutils.jl")
3739

3840
include("stage1/termination.jl")
3941

42+
function reload()
43+
@info "reloading Diffractor generators"
44+
for generator in GENERATORS
45+
Core.eval(@__MODULE__, generator)
46+
end
47+
end
48+
4049
end

src/codegen/reverse.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
# Codegen shared by both stage1 and stage2
22

3-
function make_opaque_closure(interp, typ, name, meth_nargs, isva, lno, cis, revs...)
3+
function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, cis, revs...)
44
if interp !== nothing
55
cis.inferred = true
66
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
77
typ, Union{}, cis.rettype, @__MODULE__, cis, lno.line, lno.file, meth_nargs, isva, ()).source
88
return Expr(:new_opaque_closure, typ, Union{}, Any,
99
ocm, revs...)
1010
else
11+
oc_nargs = Int64(meth_nargs)
1112
Expr(:new_opaque_closure, typ, Union{}, Any,
12-
Expr(:opaque_closure_method, name, meth_nargs, isva, lno, cis), revs...)
13+
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, cis), revs...)
1314
end
1415
end
1516

@@ -111,7 +112,8 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
111112
opaque_ci
112113
end
113114

114-
nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs
115+
nfixedargs = Int(meth.nargs)
116+
meth.isva && (nfixedargs -= 1)
115117

116118
extra_slotnames = Symbol[]
117119
extra_slotflags = UInt8[]
@@ -157,7 +159,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
157159
# TODO: Can we use the same method for each 2nd order of the transform
158160
# (except the last and the first one)
159161
for nc = 1:2:n_closures
160-
arg_accums = Union{Nothing, Vector{Any}}[nothing for i = 1:(meth.nargs)]
162+
arg_accums = Union{Nothing, Vector{Any}}[nothing for i = 1:Int(meth.nargs)]
161163
accums = Union{Nothing, Vector{Any}}[nothing for i = 1:length(ir.stmts)]
162164

163165
opaque_ci = opaque_cis[nc]
@@ -375,7 +377,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
375377
lno = LineNumberNode(1, :none)
376378
next_oc = insert_node_rev!(make_opaque_closure(interp, Tuple{(Any for i = 1:nargs+1)...},
377379
cname(nc+1, N, meth.name),
378-
meth.nargs,
380+
Int(meth.nargs),
379381
meth.isva,
380382
lno,
381383
opaque_cis[nc+1],

src/debugutils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ using Core.Compiler: AbstractInterpreter, CodeInstance, MethodInstance, WorldVie
22
using InteractiveUtils
33

44
function infer_function(interp, tt)
5+
world = Core.Compiler.get_world_counter()
6+
57
# Find all methods that are applicable to these types
6-
mthds = _methods_by_ftype(tt, -1, typemax(UInt))
8+
mthds = _methods_by_ftype(tt, -1, world)
79
if mthds === false || length(mthds) != 1
810
error("Unable to find single applicable method for $tt")
911
end
@@ -17,7 +19,6 @@ function infer_function(interp, tt)
1719
result = Core.Compiler.InferenceResult(mi)
1820

1921
# Create an InferenceState to begin inference, give it a world that is always newest
20-
world = Core.Compiler.get_world_counter()
2122
frame = Core.Compiler.InferenceState(result, #=cached=# true, interp)
2223

2324
# Run type inference on this frame. Because the interpreter is embedded

src/stage1/generated.jl

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,29 @@ struct ∂⃖recurse{N}; end
66

77
include("recurse.jl")
88

9-
function perform_optic_transform(@nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N}
9+
function perform_optic_transform(world::UInt, source::LineNumberNode,
10+
@nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N}
1011
@assert N >= 1
1112

1213
# Check if we have an rrule for this function
13-
mthds = Base._methods_by_ftype(Tuple{args...}, -1, typemax(UInt))
14-
if length(mthds) != 1
15-
return :(throw(MethodError(ff, args)))
14+
sig = Tuple{args...}
15+
mthds = Base._methods_by_ftype(sig, -1, world)
16+
if mthds === nothing || length(mthds) != 1
17+
# Core.println("[perform_optic_transform] ", sig, " => ", mthds)
18+
stub = Core.GeneratedFunctionStub(identity, Core.svec(:ff, :args), Core.svec())
19+
return stub(world, source, :(throw(MethodError(ff, args))))
1620
end
17-
match = mthds[1]
21+
match = only(mthds)::Core.MethodMatch
1822

1923
mi = Core.Compiler.specialize_method(match)
20-
ci = Core.Compiler.retrieve_code_info(mi)
24+
ci = Core.Compiler.retrieve_code_info(mi, world)
2125

2226
ci′ = copy(ci)
2327
ci′.edges = MethodInstance[mi]
2428

25-
r = transform!(ci′, mi.def, length(args) - 1, match.sparams, N)
26-
if isa(r, Expr)
27-
return r
28-
end
29+
ci′ = diffract_transform!(ci′, mi.def, length(args) - 1, match.sparams, N)
2930

30-
ci′.ssavaluetypes = length(ci′.code)
31-
ci′.ssaflags = UInt8[0 for i=1:length(ci′.code)]
32-
ci′.method_for_inference_limit_heuristics = match.method
33-
ci′
31+
return ci′
3432
end
3533

3634
# This relies on PartialStruct to infer well
@@ -402,18 +400,10 @@ end
402400
ChainRulesCore.backing(::ZeroTangent) = ZeroTangent()
403401
ChainRulesCore.backing(::NoTangent) = NoTangent()
404402

405-
function reload()
406-
Core.eval(Diffractor, quote
407-
function (ff::∂⃖recurse)(args...)
408-
$(Expr(:meta, :generated_only))
409-
$(Expr(:meta,
410-
:generated,
411-
Expr(:new,
412-
Core.GeneratedFunctionStub,
413-
:perform_optic_transform,
414-
Core.svec(:ff, :args),
415-
Core.svec())))
416-
end
417-
end)
403+
let ex = :(function (ff::∂⃖recurse)(args...)
404+
$(Expr(:meta, :generated_only))
405+
$(Expr(:meta, :generated, perform_optic_transform))
406+
end)
407+
push!(GENERATORS, ex)
408+
Core.eval(@__MODULE__, ex)
418409
end
419-
reload()

src/stage1/recurse.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -256,12 +256,12 @@ function sptypes(sparams)
256256
end
257257
end
258258

259-
function transform!(ci, meth, nargs, sparams, N)
259+
function diffract_transform!(ci, meth, nargs, sparams, N)
260260
code = ci.code
261261
cfg = compute_basic_blocks(code)
262-
slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
263-
slotflags = UInt8[(0x00 for i = 1:2)..., ci.slotflags...]
264-
slottypes = ci.slottypes === nothing ? nothing : UInt8[(Any for i = 1:2)..., ci.slottypes...]
262+
ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
263+
ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
264+
ci.slottypes = ci.slottypes === nothing ? Any[Any for _ in 1:length(ci.slotflags)] : Any[Any, Any, ci.slottypes...]
265265

266266
meta = Expr[]
267267
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],
@@ -273,19 +273,19 @@ function transform!(ci, meth, nargs, sparams, N)
273273
domtree = construct_domtree(ir.cfg.blocks)
274274
defuse_insts = scan_slot_def_use(Int(meth.nargs), ci, ir.stmts.inst)
275275
ci.ssavaluetypes = Any[Any for i = 1:ci.ssavaluetypes]
276-
ir = construct_ssa!(ci, ir, domtree, defuse_insts, Any[Any for i = 1:length(slotnames)], Core.Compiler.OptimizerLattice())
276+
ir = construct_ssa!(ci, ir, domtree, defuse_insts, ci.slottypes, Core.Compiler.OptimizerLattice())
277277
ir = compact!(ir)
278278

279279
nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs
280280
meth.isva || @assert nfixedargs == nargs+1
281281

282282
ir = diffract_ir!(ir, ci, meth, sparams, nargs, N)
283283

284-
Core.Compiler.replace_code_newstyle!(ci, ir, nargs+1)
284+
Core.Compiler.replace_code_newstyle!(ci, ir)
285+
285286
ci.ssavaluetypes = length(ci.code)
286-
ci.slotnames = slotnames
287-
ci.slotflags = slotflags
288-
ci.slottypes = slottypes
287+
ci.ssaflags = UInt8[0x00 for i=1:length(ci.code)]
288+
ci.method_for_inference_limit_heuristics = meth
289289

290-
ci
290+
return ci
291291
end

src/stage1/recurse_fwd.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,24 @@ function ∂☆nomethd(@nospecialize(args))
2828
throw(MethodError(primal(args[1]), map(primal, Base.tail(args))))
2929
end
3030

31-
function perform_fwd_transform(@nospecialize(ff::Type{∂☆recurse{N}}), @nospecialize(args)) where {N}
31+
function perform_fwd_transform(world::UInt, source::LineNumberNode,
32+
@nospecialize(ff::Type{∂☆recurse{N}}), @nospecialize(args)) where {N}
3233
if all(x->x <: ZeroBundle, args)
3334
return :(∂☆passthrough(args))
3435
end
3536

3637
# Check if we have an rrule for this function
3738
sig = Tuple{map(π, args)...}
38-
mthds = Base._methods_by_ftype(sig, -1, typemax(UInt))
39-
if length(mthds) != 1
40-
return :(∂☆nomethd(args))
39+
mthds = Base._methods_by_ftype(sig, -1, world)
40+
if mthds === nothing || length(mthds) != 1
41+
# Core.println("[perform_fwd_transform] ", sig, " => ", mthds)
42+
stub = Core.GeneratedFunctionStub(identity, Core.svec(:ff, :args), Core.svec())
43+
return stub(world, source, :(∂☆nomethd(args)))
4144
end
42-
match = mthds[1]
45+
match = only(mthds)::Core.MethodMatch
4346

4447
mi = Core.Compiler.specialize_method(match)
45-
ci = Core.Compiler.retrieve_code_info(mi)
48+
ci = Core.Compiler.retrieve_code_info(mi, world)
4649

4750
ci′ = copy(ci)
4851
ci′.edges = MethodInstance[mi]
@@ -59,16 +62,13 @@ function perform_fwd_transform(@nospecialize(ff::Type{∂☆recurse{N}}), @nospe
5962
ci′.slotflags = slotflags
6063
ci′.slottypes = slottypes
6164

62-
ci′
65+
return ci′
6366
end
6467

65-
@eval function (ff::∂☆recurse)(args...)
66-
$(Expr(:meta, :generated_only))
67-
$(Expr(:meta,
68-
:generated,
69-
Expr(:new,
70-
Core.GeneratedFunctionStub,
71-
:perform_fwd_transform,
72-
Core.svec(:ff, :args),
73-
Core.svec())))
68+
let ex = :(function (ff::∂☆recurse)(args...)
69+
$(Expr(:meta, :generated_only))
70+
$(Expr(:meta, :generated, perform_fwd_transform))
71+
end)
72+
push!(GENERATORS, ex)
73+
Core.eval(@__MODULE__, ex)
7474
end

src/stage1/termination.jl

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,72 @@
11
if :recursion_relation in fieldnames(Method)
22

33
first(methods(Diffractor.∂⃖recurse{1}())).recursion_relation = function(method1, method2, parent_sig, new_sig)
4-
# Recursion from a higher to a lower order is always allowed
5-
parent_order = parent_sig.parameters[1].parameters[1]
6-
child_order = new_sig.parameters[1].parameters[1]
7-
#@Core.Main.Base.show (parent_order, child_order)
8-
if parent_order > child_order
9-
return true
10-
end
11-
wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...}
12-
wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...}
13-
if method2 !== nothing && isdefined(method2, :recursion_relation)
14-
# TODO: What if method2 is itself a generated function.
15-
return method2.recursion_relation(method2, nothing, wrapped_parent_sig, wrapped_new_sig)
16-
end
17-
return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1)
4+
# Recursion from a higher to a lower order is always allowed
5+
parent_order = parent_sig.parameters[1].parameters[1]
6+
child_order = new_sig.parameters[1].parameters[1]
7+
#@Core.Main.Base.show (parent_order, child_order)
8+
if parent_order > child_order
9+
return true
10+
end
11+
wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...}
12+
wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...}
13+
if method2 !== nothing && isdefined(method2, :recursion_relation)
14+
# TODO: What if method2 is itself a generated function.
15+
return method2.recursion_relation(method2, nothing, wrapped_parent_sig, wrapped_new_sig)
16+
end
17+
return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1)
1818
end
1919

2020
first(methods(PrimeDerivativeBack(sin))).recursion_relation = function(method1, method2, parent_sig, new_sig)
21-
# Recursion from a higher to a lower order is always allowed
22-
parent_order = parent_sig.parameters[1].parameters[1]
23-
child_order = new_sig.parameters[1].parameters[1]
24-
#@Core.Main.Base.show (parent_order, child_order)
25-
if parent_order > child_order
26-
return true
27-
end
28-
wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...}
29-
wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...}
30-
return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1)
21+
# Recursion from a higher to a lower order is always allowed
22+
parent_order = parent_sig.parameters[1].parameters[1]
23+
child_order = new_sig.parameters[1].parameters[1]
24+
#@Core.Main.Base.show (parent_order, child_order)
25+
if parent_order > child_order
26+
return true
27+
end
28+
wrapped_parent_sig = Tuple{parent_sig.parameters[2:end]...}
29+
wrapped_new_sig = Tuple{parent_sig.parameters[2:end]...}
30+
return Core.Compiler.type_more_complex(new_sig, parent_sig, Core.svec(parent_sig), 1, 3, length(method1.sig.parameters)+1)
3131
end
3232

3333
which(Tuple{∂⃖{N}, T, Vararg{Any}} where {T,N}).recursion_relation = function(_, _, parent_sig, new_sig)
34-
# Any actual recursion will always be caught be one of the functions we're
35-
# recursing into.
36-
return isa(Base.unwrap_unionall(parent_sig.parameters[1].parameters[1]), Int) &&
34+
# Any actual recursion will always be caught be one of the functions we're
35+
# recursing into.
36+
return isa(Base.unwrap_unionall(parent_sig.parameters[1].parameters[1]), Int) &&
3737
isa(Base.unwrap_unionall(new_sig.parameters[1].parameters[1]), Int)
3838
end
3939

4040
which(Tuple{∂⃖{N}, ∂⃖{1}, Vararg{Any}} where {N}).recursion_relation = function(_, _, parent_sig, new_sig)
41-
# Allowed as long as both parent and new sig have concrete integers. In that
42-
# case, actual recursion will be caught elsewhere.
43-
return isa(Base.unwrap_unionall(parent_sig.parameters[1].parameters[1]), Int) &&
44-
isa(Base.unwrap_unionall(new_sig.parameters[1].parameters[1]), Int)
41+
# Allowed as long as both parent and new sig have concrete integers. In that
42+
# case, actual recursion will be caught elsewhere.
43+
return isa(Base.unwrap_unionall(parent_sig.parameters[1].parameters[1]), Int) &&
44+
isa(Base.unwrap_unionall(new_sig.parameters[1].parameters[1]), Int)
4545
end
4646

4747
for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆recurse{N}, Vararg{Any}} where {N}, nothing, -1, Base.get_world_counter())
48-
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
49-
# Recursion from a higher to a lower order is always allowed
50-
parent_order = parent_sig.parameters[1].parameters[1]
51-
child_order = new_sig.parameters[1].parameters[1]
52-
#@Core.Main.Base.show (parent_order, child_order)
53-
if parent_order > child_order
54-
return true
55-
end
56-
@show (parent_sig, new_sig)
57-
return false
58-
end
48+
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
49+
# Recursion from a higher to a lower order is always allowed
50+
parent_order = parent_sig.parameters[1].parameters[1]
51+
child_order = new_sig.parameters[1].parameters[1]
52+
if parent_order > child_order
53+
return true
54+
end
55+
Core.Compiler.@show (parent_sig, new_sig)
56+
return false
57+
end
5958
end
6059

6160
for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆internal{N}, Vararg{Any}} where {N}, nothing, -1, Base.get_world_counter())
62-
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
63-
return true
64-
end
61+
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
62+
return true
63+
end
6564
end
6665

6766
for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆{N}, Vararg{Any}} where {N}, nothing, -1, Base.get_world_counter())
68-
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
69-
return true
70-
end
67+
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
68+
return true
69+
end
7170
end
7271

7372
end

0 commit comments

Comments
 (0)