Skip to content

Commit dc6746e

Browse files
committed
Treat getproperty(::Module, ::Symbol) like GlobalRefs
1 parent 7b7b757 commit dc6746e

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

src/codegen/reverse.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,8 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
289289
if isa(stmt, Core.ReturnNode)
290290
accum!(stmt.val, Argument(2))
291291
current_env = nothing
292+
elseif is_global_access(ir, stmt)
293+
# Treat it as a GlobalRef, dropping gradients.
292294
elseif isexpr(stmt, :call) || isexpr(stmt, :invoke)
293295
Δ = do_accum(SSAValue(i))
294296
callee = retrieve_ctx_obj(current_env, i)
@@ -453,7 +455,9 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
453455
end
454456
stmt = urs[]
455457

456-
if isexpr(stmt, :call)
458+
if is_global_access(ir, stmt)
459+
fwds[i] = ZeroTangent()
460+
elseif isexpr(stmt, :call)
457461
callee = insert_node_here!(Expr(:call, getfield, Argument(1), i))
458462
pushfirst!(stmt.args, callee)
459463
call = insert_node_here!(stmt)
@@ -565,7 +569,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
565569
if isexpr(stmt, :(=))
566570
stmt = stmt.args[2]
567571
end
568-
if isexpr(stmt, :call)
572+
if isexpr(stmt, :call) && !is_global_access(compact, stmt)
569573
compact[SSAValue(idx)] = Expr(:call, ∂⃖{N}(), stmt.args...)
570574
if isexpr(orig_stmt, :(=))
571575
orig_stmt.args[2] = stmt
@@ -677,3 +681,18 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
677681

678682
return ir
679683
end
684+
685+
eval_globalref(x) = x
686+
eval_globalref(x::GlobalRef) = getglobal(x.mod, x.name)
687+
ssa_def(ir, idx::SSAValue) = ssa_def(ir, ir[idx][:inst])
688+
ssa_def(ir, def) = def
689+
690+
function is_global_access(ir::Union{IRCode,IncrementalCompact}, stmt)
691+
isexpr(stmt, :call, 3) || return false
692+
f = eval_globalref(ssa_def(ir, stmt.args[1]))
693+
f === getproperty || return false
694+
from = eval_globalref(ssa_def(ir, stmt.args[2]))
695+
isa(from, Module) || return false
696+
name = stmt.args[3]
697+
isa(name, QuoteNode) && isa(name.value, Symbol)
698+
end

0 commit comments

Comments
 (0)