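This is neither a bug (I think) nor a feature request; I need a bit of help implementing forward and reverse rules for an implicitly calculated quantity. Specifically, let $p$ be calculated as the solution to

$$p = H\big(pA + (1 - p)B\big),$$

where $H$ is the CDF of the standard normal distribution (so $H'$ is its density).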
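For reference, differentiating this condition implicitly gives the partials that either rule needs. This is a sketch via the implicit function theorem, writing $x = pA + (1 - p)B$ and $F(p, A, B) = H(x) - p$:

$$
\begin{aligned}
\frac{\partial F}{\partial p} &= H'(x)\,(A - B) - 1, \qquad
\frac{\partial F}{\partial A} = H'(x)\,p, \qquad
\frac{\partial F}{\partial B} = H'(x)\,(1 - p), \\
\frac{\partial p}{\partial A} &= -\frac{\partial F/\partial A}{\partial F/\partial p} = \frac{H'(x)\,p}{1 - H'(x)\,(A - B)}, \qquad
\frac{\partial p}{\partial B} = -\frac{\partial F/\partial B}{\partial F/\partial p} = \frac{H'(x)\,(1 - p)}{1 - H'(x)\,(A - B)}.
\end{aligned}
$$

Because $H'(x) \le 1/\sqrt{2\pi}$, the denominator stays positive whenever $|A - B| < \sqrt{2\pi}$.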
This defines the primal and a utility function to calculate the partials:
```julia
using Roots, Distributions, Enzyme
using .EnzymeRules

function solvep(A::T, B::T) where {T<:AbstractFloat}
    condition(p) = cdf(Normal(), p * A + (1 - p) * B) - p
    find_zero(condition, (0.0, 1.0))
end

# Implicit differentiation of H(p*A + (1 - p)*B) - p = 0 with respect to A and B
function partialsp(A, B, p)
    H′ = pdf(Normal(), p * A + (1 - p) * B)
    C = H′ / (1 - H′ * (A - B))
    (∂p∂A = C * p, ∂p∂B = C * (1 - p))
end
```
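As a sanity check on the partials, a central finite-difference comparison against the primal can be run without Enzyme at all. This is a sketch. The helper `fd_partials`, the step `h = 1e-6`, and the test point `(0.3, -0.2)` are illustrative choices, not part of any library. `partialsp` here uses the expressions obtained by implicitly differentiating the condition.

```julia
using Roots, Distributions

# Primal: p solves H(p*A + (1 - p)*B) = p, with H the standard normal CDF.
function solvep(A::T, B::T) where {T<:AbstractFloat}
    condition(p) = cdf(Normal(), p * A + (1 - p) * B) - p
    find_zero(condition, (0.0, 1.0))   # a bracket with no method given: bisection on [0, 1]
end

# Implicit partials of p with respect to A and B.
function partialsp(A, B, p)
    H′ = pdf(Normal(), p * A + (1 - p) * B)
    C = H′ / (1 - H′ * (A - B))
    (∂p∂A = C * p, ∂p∂B = C * (1 - p))
end

# Hypothetical helper: central finite differences of the primal.
function fd_partials(A, B; h = 1e-6)
    ∂p∂A = (solvep(A + h, B) - solvep(A - h, B)) / (2h)
    ∂p∂B = (solvep(A, B + h) - solvep(A, B - h)) / (2h)
    (; ∂p∂A, ∂p∂B)
end

A, B = 0.3, -0.2
p = solvep(A, B)
analytic = partialsp(A, B, p)
numeric = fd_partials(A, B)
println("analytic: ", analytic)
println("numeric:  ", numeric)
```

Central differences are used because their truncation error is O(h²). Bisection returns `p` to within about one ulp, so the two results should agree to roughly 1e-9.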
and I thought this would work:
```julia
function EnzymeRules.forward(config::FwdConfig, func::Const{typeof(solvep)},
                             ::Type{<:Duplicated}, A::Duplicated, B::Duplicated)
    println("using custom rule")
    p = solvep(A.val, B.val)
    (; ∂p∂A, ∂p∂B) = partialsp(A.val, B.val, p)
    Duplicated(p, ∂p∂A * A.dval + ∂p∂B * B.dval)
end
```
but

```julia
autodiff(ForwardWithPrimal, solvep, Duplicated(1.0, 2.0), Duplicated(1.0, 1.0))
```

gives me a long machine-instruction dump followed by this stacktrace:
```
Stacktrace:
  [1] post_optimze!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7Bzkr/src/compiler/optimize.jl:744
  [2] post_optimze!
    @ ~/.julia/packages/Enzyme/7Bzkr/src/compiler/optimize.jl:722 [inlined]
  [3] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget{GPUCompiler.NativeCompilerTarget}, Enzyme.Compiler.EnzymeCompilerParams{Enzyme.Compiler.PrimalCompilerParams}}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7Bzkr/src/compiler.jl:5722
  [4] _thunk
    @ ~/.julia/packages/Enzyme/7Bzkr/src/compiler.jl:5695 [inlined]
  [5] cached_compilation
    @ ~/.julia/packages/Enzyme/7Bzkr/src/compiler.jl:5749 [inlined]
  [6] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7Bzkr/src/compiler.jl:5863
  [7] thunk_generator(world::UInt64, source::Union{…}, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type, strongzero::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7Bzkr/src/compiler.jl:6056
  [8] autodiff
    @ ~/.julia/packages/Enzyme/7Bzkr/src/Enzyme.jl:654 [inlined]
  [9] autodiff
    @ ~/.julia/packages/Enzyme/7Bzkr/src/Enzyme.jl:558 [inlined]
 [10] autodiff(::ForwardMode{true, FFIABI, false, false, false}, ::typeof(solvep), ::Duplicated{Float64}, ::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/7Bzkr/src/Enzyme.jl:530
 [11] top-level scope
    @ REPL[13]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
So I am probably making a mistake somewhere. I am using

```
  [31c24e10] Distributions v0.25.120
  [7da242da] Enzyme v0.13.67
  [57b37032] ImplicitDifferentiation v0.9.0
  [f2b01f46] Roots v2.2.8
```

and Julia 1.11.6.
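For comparison, here is a sketch of both rules that follows the pattern in the Enzyme v0.13 custom-rules documentation. The forward rule branches on `EnzymeRules.needs_primal(config)` and returns `Duplicated(p, dp)` only when the primal is requested (otherwise just the tangent `dp`). The reverse rule stores `p` as the tape in `augmented_primal` and scales the cotangent by the implicit partials in `reverse`. The signatures (`FwdConfig`, `RevConfigWidth{1}`, `AugmentedReturn(primal, shadow, tape)`) are assumed from that documentation. Only scalar `Duplicated`/`Active` arguments are handled, and whether this avoids the compilation failure above is untested.

```julia
using Roots, Distributions, Enzyme
using Enzyme: EnzymeRules

function solvep(A::T, B::T) where {T<:AbstractFloat}
    condition(p) = cdf(Normal(), p * A + (1 - p) * B) - p
    find_zero(condition, (0.0, 1.0))
end

# Implicit partials of p with respect to A and B.
function partialsp(A, B, p)
    H′ = pdf(Normal(), p * A + (1 - p) * B)
    C = H′ / (1 - H′ * (A - B))
    (∂p∂A = C * p, ∂p∂B = C * (1 - p))
end

# Forward mode: push the input tangents through the implicit partials.
# Assumes the Enzyme v0.13 FwdConfig API; scalar Duplicated arguments only.
function EnzymeRules.forward(config::EnzymeRules.FwdConfig, func::Const{typeof(solvep)},
                             ::Type{<:Duplicated}, A::Duplicated, B::Duplicated)
    p = solvep(A.val, B.val)
    (; ∂p∂A, ∂p∂B) = partialsp(A.val, B.val, p)
    dp = ∂p∂A * A.dval + ∂p∂B * B.dval
    if EnzymeRules.needs_primal(config)
        return Duplicated(p, dp)   # e.g. ForwardWithPrimal: primal and tangent
    else
        return dp                  # e.g. Forward: tangent only
    end
end

# Reverse mode, forward sweep: solve once and keep p as the tape.
function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
                                      func::Const{typeof(solvep)}, ::Type{<:Active},
                                      A::Active, B::Active)
    p = solvep(A.val, B.val)
    primal = EnzymeRules.needs_primal(config) ? p : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing, p)   # (primal, shadow, tape)
end

# Reverse mode, backward sweep: scale the output cotangent by the partials,
# returning one entry per argument (both are Active here).
function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
                             func::Const{typeof(solvep)}, dret::Active, tape,
                             A::Active, B::Active)
    (; ∂p∂A, ∂p∂B) = partialsp(A.val, B.val, tape)
    return (dret.val * ∂p∂A, dret.val * ∂p∂B)
end
```

These rules would be exercised by the forward call above. In reverse mode they would be exercised by `autodiff(Reverse, solvep, Active, Active(1.0), Active(1.0))`, whose first element holds the derivatives with respect to `A` and `B`.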