Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Base.getindex(o::OpticBundle, i::Int) = i == 1 ? o.x :
Base.iterate(o::OpticBundle) = (o.x, nothing)
Base.iterate(o::OpticBundle, ::Nothing) = (o.clos, missing)
Base.iterate(o::OpticBundle, ::Missing) = nothing
Base.length(o::OpticBundle) = 2

# Desturucture using `getfield` rather than iterate to make
# inference happier
Expand Down Expand Up @@ -227,7 +228,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
end

function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T}
∂⃖{1}()(f, args...)
∂⃖{1}()(f, args...) |> Tuple{Any, Any}
end

@Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...)
Expand Down
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ x43 = rand(10, 10)
@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}}

# PR # 45 - Calling back into AD from ChainRules
y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2)
r45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2)
@test r45 isa Tuple
y45, back45 = r45
@test y45 ≈ 2.0
@test back45(1) == (ZeroTangent(), 1.0)

Expand Down