Skip to content

Commit 87731ed

Browse files
authored
remove the hack (#129)
1 parent 16e649d commit 87731ed

File tree

5 files changed

+7
-23
lines changed

5 files changed

+7
-23
lines changed

src/Diffractor.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,4 @@ include("debugutils.jl")
3939

4040
include("stage1/termination.jl")
4141

42-
function reload()
43-
@info "reloading Diffractor generators"
44-
for generator in GENERATORS
45-
Core.eval(@__MODULE__, generator)
46-
end
47-
end
48-
4942
end

src/stage1/generated.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -395,10 +395,7 @@ end
395395
ChainRulesCore.backing(::ZeroTangent) = ZeroTangent()
396396
ChainRulesCore.backing(::NoTangent) = NoTangent()
397397

398-
let ex = :(function (ff::∂⃖recurse)(args...)
399-
$(Expr(:meta, :generated_only))
400-
$(Expr(:meta, :generated, perform_optic_transform))
401-
end)
402-
push!(GENERATORS, ex)
403-
Core.eval(@__MODULE__, ex)
398+
@eval function (ff::∂⃖recurse)(args...)
399+
$(Expr(:meta, :generated_only))
400+
$(Expr(:meta, :generated, perform_optic_transform))
404401
end

src/stage1/recurse_fwd.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,7 @@ function perform_fwd_transform(world::UInt, source::LineNumberNode,
5050
return fwd_transform(ci, mi, length(args)-1, N)
5151
end
5252

53-
let ex = :(function (ff::∂☆recurse)(args...)
54-
$(Expr(:meta, :generated_only))
55-
$(Expr(:meta, :generated, perform_fwd_transform))
56-
end)
57-
push!(GENERATORS, ex)
58-
Core.eval(@__MODULE__, ex)
53+
@eval function (ff::∂☆recurse)(args...)
54+
$(Expr(:meta, :generated_only))
55+
$(Expr(:meta, :generated, perform_fwd_transform))
5956
end

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,7 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
226226
@test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] [0.2338, -0.0177, -0.0661] atol=1e-3
227227
@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],)
228228

229-
# XXX the world-age limitation is preventing this test from passing
230-
# @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad
229+
@test_broken gradient(x -> sum((explog).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad
231230
exp_log(x) = exp(log(x))
232231
@test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
233232
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])

test/stage2_fwd.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ module stage2_fwd
1212
myminus(a, b) = a - b
1313
ChainRulesCore.@scalar_rule myminus(x, y) (true, -1)
1414

15-
Diffractor.reload() # XXX we should remove this
16-
1715
self_minus(a) = myminus(a, a)
1816
let self_minus′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}, 2)
1917
@test isa(self_minus′′, Core.OpaqueClosure{Tuple{Float64}, Float64})

0 commit comments

Comments
 (0)