Skip to content

Commit 640b315

Browse files
committed
1 parent 4438c41 commit 640b315

File tree

3 files changed

+17
-22
lines changed

3 files changed

+17
-22
lines changed

src/codegen/reverse.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Codegen shared by both stage1 and stage2
22

3-
function make_opaque_closure(interp, typ, name, meth_nargs, isva, lno, cis, revs...)
3+
function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, cis, revs...)
44
if interp !== nothing
55
cis.inferred = true
66
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
@@ -112,7 +112,8 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
112112
opaque_ci
113113
end
114114

115-
nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs
115+
nfixedargs = Int(meth.nargs)
116+
meth.isva && (nfixedargs -= 1)
116117

117118
extra_slotnames = Symbol[]
118119
extra_slotflags = UInt8[]
@@ -158,7 +159,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
158159
# TODO: Can we use the same method for each 2nd order of the transform
159160
# (except the last and the first one)
160161
for nc = 1:2:n_closures
161-
arg_accums = Union{Nothing, Vector{Any}}[nothing for i = 1:(meth.nargs)]
162+
arg_accums = Union{Nothing, Vector{Any}}[nothing for i = 1:Int(meth.nargs)]
162163
accums = Union{Nothing, Vector{Any}}[nothing for i = 1:length(ir.stmts)]
163164

164165
opaque_ci = opaque_cis[nc]
@@ -376,7 +377,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
376377
lno = LineNumberNode(1, :none)
377378
next_oc = insert_node_rev!(make_opaque_closure(interp, Tuple{(Any for i = 1:nargs+1)...},
378379
cname(nc+1, N, meth.name),
379-
meth.nargs,
380+
Int(meth.nargs),
380381
meth.isva,
381382
lno,
382383
opaque_cis[nc+1],

src/stage1/generated.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,9 @@ function perform_optic_transform(@nospecialize(ff::Type{∂⃖recurse{N}}), @nos
2222
ci′ = copy(ci)
2323
ci′.edges = MethodInstance[mi]
2424

25-
r = transform!(ci′, mi.def, length(args) - 1, match.sparams, N)
26-
if isa(r, Expr)
27-
return r
28-
end
25+
ci′ = diffract_transform!(ci′, mi.def, length(args) - 1, match.sparams, N)
2926

30-
ci′.ssavaluetypes = length(ci′.code)
31-
ci′.ssaflags = UInt8[0 for i=1:length(ci′.code)]
32-
ci′.method_for_inference_limit_heuristics = match.method
33-
ci′
27+
return ci′
3428
end
3529

3630
# This relies on PartialStruct to infer well

src/stage1/recurse.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -256,12 +256,12 @@ function sptypes(sparams)
256256
end
257257
end
258258

259-
function transform!(ci, meth, nargs, sparams, N)
259+
function diffract_transform!(ci, meth, nargs, sparams, N)
260260
code = ci.code
261261
cfg = compute_basic_blocks(code)
262-
slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
263-
slotflags = UInt8[(0x00 for i = 1:2)..., ci.slotflags...]
264-
slottypes = ci.slottypes === nothing ? nothing : UInt8[(Any for i = 1:2)..., ci.slottypes...]
262+
ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
263+
ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
264+
ci.slottypes = ci.slottypes === nothing ? Any[Any for _ in 1:length(ci.slotflags)] : Any[Any, Any, ci.slottypes...]
265265

266266
meta = Expr[]
267267
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],
@@ -273,19 +273,19 @@ function transform!(ci, meth, nargs, sparams, N)
273273
domtree = construct_domtree(ir.cfg.blocks)
274274
defuse_insts = scan_slot_def_use(Int(meth.nargs), ci, ir.stmts.inst)
275275
ci.ssavaluetypes = Any[Any for i = 1:ci.ssavaluetypes]
276-
ir = construct_ssa!(ci, ir, domtree, defuse_insts, Any[Any for i = 1:length(slotnames)], Core.Compiler.OptimizerLattice())
276+
ir = construct_ssa!(ci, ir, domtree, defuse_insts, ci.slottypes, Core.Compiler.OptimizerLattice())
277277
ir = compact!(ir)
278278

279279
nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs
280280
meth.isva || @assert nfixedargs == nargs+1
281281

282282
ir = diffract_ir!(ir, ci, meth, sparams, nargs, N)
283283

284-
Core.Compiler.replace_code_newstyle!(ci, ir, nargs+1)
284+
Core.Compiler.replace_code_newstyle!(ci, ir)
285+
285286
ci.ssavaluetypes = length(ci.code)
286-
ci.slotnames = slotnames
287-
ci.slotflags = slotflags
288-
ci.slottypes = slottypes
287+
ci.ssaflags = UInt8[0x00 for i=1:length(ci.code)]
288+
ci.method_for_inference_limit_heuristics = meth
289289

290-
ci
290+
return ci
291291
end

0 commit comments

Comments
 (0)