Skip to content

Commit ac7bce4

Browse files
committed
Use rrule for getproperty(::Module, ::Symbol)
1 parent 7f126aa commit ac7bce4

File tree

2 files changed

+8
-28
lines changed

2 files changed

+8
-28
lines changed

src/codegen/reverse.jl

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,6 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
289289
if isa(stmt, Core.ReturnNode)
290290
accum!(stmt.val, Argument(2))
291291
current_env = nothing
292-
elseif is_global_access(ir, stmt)
293-
# Treat it as a GlobalRef, dropping gradients.
294292
elseif isexpr(stmt, :call) || isexpr(stmt, :invoke)
295293
Δ = do_accum(SSAValue(i))
296294
callee = retrieve_ctx_obj(current_env, i)
@@ -455,9 +453,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
455453
end
456454
stmt = urs[]
457455

458-
if is_global_access(ir, stmt)
459-
fwds[i] = ZeroTangent()
460-
elseif isexpr(stmt, :call)
456+
if isexpr(stmt, :call)
461457
callee = insert_node_here!(Expr(:call, getfield, Argument(1), i))
462458
pushfirst!(stmt.args, callee)
463459
call = insert_node_here!(stmt)
@@ -569,7 +565,7 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
569565
if isexpr(stmt, :(=))
570566
stmt = stmt.args[2]
571567
end
572-
if isexpr(stmt, :call) && !is_global_access(compact, stmt)
568+
if isexpr(stmt, :call)
573569
compact[SSAValue(idx)] = Expr(:call, ∂⃖{N}(), stmt.args...)
574570
if isexpr(orig_stmt, :(=))
575571
orig_stmt.args[2] = stmt
@@ -681,25 +677,3 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
681677

682678
return ir
683679
end
684-
685-
eval_globalref(x) = x
686-
function eval_globalref(ref::GlobalRef)
687-
isdefined(ref.mod, ref.name) || return nothing
688-
getproperty(ref.mod, ref.name)
689-
end
690-
ssa_def(ir, idx::SSAValue) = ssa_def(ir, ir[idx][:inst])
691-
ssa_def(ir, def) = def
692-
693-
function is_global_access(ir::Union{IRCode,IncrementalCompact}, stmt)
694-
isexpr(stmt, :call, 3) || return false
695-
f = ssa_def(ir, stmt.args[1])
696-
if isa(f, GlobalRef)
697-
f.name === :getproperty || return false
698-
f = eval_globalref(f)
699-
end
700-
f === getproperty || return false
701-
from = eval_globalref(ssa_def(ir, stmt.args[2]))
702-
isa(from, Module) || return false
703-
name = stmt.args[3]
704-
isa(name, QuoteNode) && isa(name.value, Symbol)
705-
end

src/extra_rules.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,12 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk},
268268
val, Δ->(NoTangent(), NoTangent(), Δ)
269269
end
270270

271+
# XXX: We should instead skip differentiation in the IR.
272+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getproperty), mod::Module, name::Symbol)
273+
val = getproperty(mod, name)
274+
val, Δ->(NoTangent(), NoTangent(), NoTangent())
275+
end
276+
271277
Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581
272278

273279
# Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495

0 commit comments

Comments
 (0)