Skip to content

Commit da2c0bb

Browse files
authored
Merge pull request #168 from JuliaDiff/sf/globalref_types
Improve type inference of non-`const` GlobalRef's
2 parents 4746542 + e0963a9 commit da2c0bb

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

src/codegen/forward_demand.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -256,11 +256,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
256256
# TODO: Should we remember whether the callbacks wanted the arg?
257257
return transform!(ir, arg, order)
258258
elseif isa(arg, GlobalRef)
259-
if !isconst(arg)
260-
# Non-const GlabalRefs need to need to be accessed as seperate statements
261-
arg = insert_node!(ir, ssa, NewInstruction(arg, Any))
262-
end
263-
259+
@assert isconst(arg)
264260
return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg), Any))
265261
elseif isa(arg, QuoteNode)
266262
return ZeroBundle{order}(arg.value)
@@ -302,7 +298,15 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
302298
# TODO: New PiNode that discriminates based on primal?
303299
inst[:inst] = maparg(stmt.val, SSAValue(ssa), order)
304300
inst[:type] = Any
305-
elseif isa(stmt, GlobalRef) || isa(stmt, SSAValue) || isa(stmt, QuoteNode)
301+
elseif isa(stmt, GlobalRef)
302+
if !isconst(stmt)
303+
# Non-const GlobalRefs need to need to be accessed as seperate statements
304+
stmt = insert_node!(ir, ssa, NewInstruction(inst))
305+
end
306+
307+
inst[:inst] = Expr(:call, ZeroBundle{order}, stmt)
308+
inst[:type] = Any
309+
elseif isa(stmt, SSAValue) || isa(stmt, QuoteNode)
306310
inst[:inst] = maparg(stmt, SSAValue(ssa), order)
307311
inst[:type] = Any
308312
elseif isa(stmt, Expr) || isa(stmt, PhiNode) || isa(stmt, PhiCNode) ||

test/forward_diff_no_inf.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ module forward_diff_no_inf
2828
Diffractor.forward_diff_no_inf!(ir, [Core.SSAValue(1) => 1]; transform! = identity_transform!)
2929
ir2 = Core.Compiler.compact!(ir)
3030
Core.Compiler.verify_ir(ir2) # This would error if we were not handling nonconst globals correctly
31+
# Assert that the reference to `Main._coeff` is properly typed
32+
stmt_idx = findfirst(stmt -> isa(stmt[:inst], GlobalRef), collect(ir2.stmts))
33+
stmt = ir2.stmts[stmt_idx]
34+
@test stmt[:inst].name == :_coeff
35+
@test stmt[:type] == Float64
3136
f = Core.OpaqueClosure(ir2; do_compile=false)
3237
@test f(3.5) == 28.0
3338
end

0 commit comments

Comments
 (0)