Skip to content

Commit 16e649d

Browse files
authored
NFC cosmetic changes (#128)
1 parent 84ea690 commit 16e649d

File tree

4 files changed

+40
-33
lines changed

4 files changed

+40
-33
lines changed

src/codegen/forward.jl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
1-
function transform_fwd!(ci, meth, nargs, sparams, N)
1+
function fwd_transform(ci, args...)
2+
newci = copy(ci)
3+
fwd_transform!(newci, args...)
4+
return newci
5+
end
6+
7+
function fwd_transform!(ci, mi, nargs, N)
28
new_code = Any[]
39
new_codelocs = Any[]
410
ssa_mapping = Int[]
511
loc_mapping = Int[]
612

7-
function emit!(stmt)
13+
function emit!(@nospecialize stmt)
814
(isexpr(stmt, :call) || isexpr(stmt, :(=)) || isexpr(stmt, :new)) || return stmt
915
push!(new_code, stmt)
1016
push!(new_codelocs, isempty(new_codelocs) ? 0 : new_codelocs[end])
11-
SSAValue(length(new_code))
17+
return SSAValue(length(new_code))
1218
end
1319

14-
function mapstmt!(stmt)
20+
function mapstmt!(@nospecialize stmt)
1521
if isexpr(stmt, :(=))
1622
return Expr(stmt.head, emit!(mapstmt!(stmt.args[1])), emit!(mapstmt!(stmt.args[2])))
1723
elseif isexpr(stmt, :call)
@@ -44,7 +50,7 @@ function transform_fwd!(ci, meth, nargs, sparams, N)
4450
elseif isa(stmt, GotoIfNot)
4551
return GotoIfNot(emit!(Expr(:call, primal, emit!(mapstmt!(stmt.cond)))), stmt.dest)
4652
elseif isexpr(stmt, :static_parameter)
47-
return ZeroBundle{N}(sparams[stmt.args[1]])
53+
return ZeroBundle{N}(mi.sparam_vals[stmt.args[1]::Int])
4854
elseif isexpr(stmt, :foreigncall)
4955
return Expr(:call, error, "Attempted to AD a foreigncall. Missing rule?")
5056
elseif isexpr(stmt, :meta) || isexpr(stmt, :inbounds)
@@ -56,9 +62,11 @@ function transform_fwd!(ci, meth, nargs, sparams, N)
5662
end
5763
end
5864

59-
for i = 1:meth.nargs
60-
if meth.isva && i == meth.nargs
61-
args = map(i:(nargs+1)) do j
65+
meth = mi.def::Method
66+
nargs = Int(meth.nargs)
67+
for i = 1:nargs
68+
if meth.isva && i == nargs
69+
args = map(i:(nargs+1)) do j::Int
6270
emit!(Expr(:call, getfield, SlotNumber(2), j))
6371
end
6472
emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, ∂vararg{N}(), args...)))
@@ -83,7 +91,15 @@ function transform_fwd!(ci, meth, nargs, sparams, N)
8391
end
8492
end
8593

94+
ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
95+
ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
96+
ci.slottypes = ci.slottypes === nothing ? nothing : Any[Any, Any, ci.slottypes...]
8697
ci.code = new_code
8798
ci.codelocs = new_codelocs
88-
ci
99+
ci.ssavaluetypes = length(new_code)
100+
ci.ssaflags = UInt8[0 for i=1:length(new_code)]
101+
ci.method_for_inference_limit_heuristics = meth
102+
ci.edges = MethodInstance[mi]
103+
104+
return ci
89105
end

src/stage1/generated.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,7 @@ function perform_optic_transform(world::UInt, source::LineNumberNode,
2323
mi = Core.Compiler.specialize_method(match)
2424
ci = Core.Compiler.retrieve_code_info(mi, world)
2525

26-
ci′ = copy(ci)
27-
ci′.edges = MethodInstance[mi]
28-
29-
ci′ = diffract_transform!(ci′, mi.def, length(args) - 1, match.sparams, N)
30-
31-
return ci′
26+
return optic_transform(ci, mi, length(args)-1, N)
3227
end
3328

3429
# This relies on PartialStruct to infer well

src/stage1/recurse.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,16 @@ function sptypes(sparams)
256256
end
257257
end
258258

259-
function diffract_transform!(ci, meth, nargs, sparams, N)
259+
function optic_transform(ci, args...)
260+
newci = copy(ci)
261+
optic_transform!(newci, args...)
262+
return newci
263+
end
264+
265+
function optic_transform!(ci, mi, nargs, N)
260266
code = ci.code
267+
sparams = mi.sparam_vals
268+
261269
cfg = compute_basic_blocks(code)
262270
ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
263271
ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
@@ -270,13 +278,15 @@ function diffract_transform!(ci, meth, nargs, sparams, N)
270278
Any[Any for i = 1:2], meta, sptypes(sparams))
271279

272280
# SSA conversion
281+
meth = mi.def::Method
273282
domtree = construct_domtree(ir.cfg.blocks)
274283
defuse_insts = scan_slot_def_use(Int(meth.nargs), ci, ir.stmts.inst)
275284
ci.ssavaluetypes = Any[Any for i = 1:ci.ssavaluetypes]
276285
ir = construct_ssa!(ci, ir, domtree, defuse_insts, ci.slottypes, Core.Compiler.OptimizerLattice())
277286
ir = compact!(ir)
278287

279-
nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs
288+
nfixedargs = Int(meth.nargs)
289+
meth.isva && (nfixedargs -= 1)
280290
meth.isva || @assert nfixedargs == nargs+1
281291

282292
ir = diffract_ir!(ir, ci, meth, sparams, nargs, N)
@@ -286,6 +296,7 @@ function diffract_transform!(ci, meth, nargs, sparams, N)
286296
ci.ssavaluetypes = length(ci.code)
287297
ci.ssaflags = UInt8[0x00 for i=1:length(ci.code)]
288298
ci.method_for_inference_limit_heuristics = meth
299+
ci.edges = MethodInstance[mi]
289300

290301
return ci
291302
end

src/stage1/recurse_fwd.jl

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,7 @@ function perform_fwd_transform(world::UInt, source::LineNumberNode,
4747
mi = Core.Compiler.specialize_method(match)
4848
ci = Core.Compiler.retrieve_code_info(mi, world)
4949

50-
ci′ = copy(ci)
51-
ci′.edges = MethodInstance[mi]
52-
53-
transform_fwd!(ci′, mi.def, length(args) - 1, match.sparams, N)
54-
55-
ci′.ssavaluetypes = length(ci′.code)
56-
ci′.ssaflags = UInt8[0 for i=1:length(ci′.code)]
57-
ci′.method_for_inference_limit_heuristics = match.method
58-
slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
59-
slotflags = UInt8[(0x00 for i = 1:2)..., ci.slotflags...]
60-
slottypes = ci.slottypes === nothing ? nothing : Any[(Any for i = 1:2)..., ci.slottypes...]
61-
ci′.slotnames = slotnames
62-
ci′.slotflags = slotflags
63-
ci′.slottypes = slottypes
64-
65-
return ci′
50+
return fwd_transform(ci, mi, length(args)-1, N)
6651
end
6752

6853
let ex = :(function (ff::∂☆recurse)(args...)

0 commit comments

Comments
 (0)