help with R²→R implicit rule #2507

@tpapp

This is neither a bug (I think) nor a feature request; I need a bit of help implementing forward and reverse rules for an implicitly calculated quantity. Specifically, let $p$ be calculated as the solution to

$$p = H((1-p) \cdot A + p \cdot B)$$

where $H$ is the cdf of the $\mathrm{Normal}(0, 1)$ distribution. Note that everything is a scalar.
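
For reference, the partials used below follow from implicit differentiation of this condition. This is a sketch, writing $z = (1-p) \cdot A + p \cdot B$ for the argument of $H$ and assuming $1 - H'(z) \cdot (B - A) \neq 0$. Taking total differentials of both sides gives

$$dp = H'(z) \cdot \left[ (1-p) \, dA + p \, dB + (B - A) \, dp \right]$$

and collecting the $dp$ terms yields

$$\frac{\partial p}{\partial A} = \frac{H'(z) \cdot (1-p)}{1 - H'(z) \cdot (B - A)}, \qquad \frac{\partial p}{\partial B} = \frac{H'(z) \cdot p}{1 - H'(z) \cdot (B - A)}$$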

The code below defines the primal and a utility function that computes the partials:

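using Roots, Distributions, Enzyme
using .EnzymeRules

# Primal: solve p = H((1 - p) * A + p * B) for p ∈ (0, 1) by bracketing.
function solvep(A::T, B::T) where {T<:AbstractFloat}
    condition(p) = cdf(Normal(), (1 - p) * A + p * B) - p
    find_zero(condition, (zero(T), one(T)))
end

# Partials of p with respect to A and B, from implicit differentiation
# of the condition above; H′ is the standard normal pdf at the solution.
function partialsp(A, B, p)
    H′ = pdf(Normal(), (1 - p) * A + p * B)
    C = H′ / (1 - H′ * (B - A))
    (∂p∂A = C * (1 - p), ∂p∂B = C * p)
end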

and I thought this would work:

function EnzymeRules.forward(config::FwdConfig, func::Const{typeof(solvep)},
                             ::Type{<:Duplicated}, A::Duplicated, B::Duplicated)
    println("using custom rule")
    p = solvep(A.val, B.val)
    (; ∂p∂A, ∂p∂B) = partialsp(A.val, B.val, p)
    Duplicated(p, ∂p∂A * A.dval + ∂p∂B * B.dval)
end

but

autodiff(ForwardWithPrimal, solvep, Duplicated(1.0, 2.0), Duplicated(1.0, 1.0))

gives me a long machine instruction dump followed by a stacktrace

Stacktrace:
  [1] post_optimze!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7Bzkr/src/compiler/optimize.jl:744
  [2] post_optimze!
    @ ~/.julia/packages/Enzyme/7Bzkr/src/compiler/optimize.jl:722 [inlined]
  [3] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget{GPUCompiler.NativeCompilerTarget}, Enzyme.Compiler.EnzymeCompilerParams{Enzyme.Compiler.PrimalCompilerParams}}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7Bzkr/src/compiler.jl:5722
  [4] _thunk
    @ ~/.julia/packages/Enzyme/7Bzkr/src/compiler.jl:5695 [inlined]
  [5] cached_compilation
    @ ~/.julia/packages/Enzyme/7Bzkr/src/compiler.jl:5749 [inlined]
  [6] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7Bzkr/src/compiler.jl:5863
  [7] thunk_generator(world::UInt64, source::Union{…}, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type, strongzero::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/7Bzkr/src/compiler.jl:6056
  [8] autodiff
    @ ~/.julia/packages/Enzyme/7Bzkr/src/Enzyme.jl:654 [inlined]
  [9] autodiff
    @ ~/.julia/packages/Enzyme/7Bzkr/src/Enzyme.jl:558 [inlined]
 [10] autodiff(::ForwardMode{true, FFIABI, false, false, false}, ::typeof(solvep), ::Duplicated{Float64}, ::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/7Bzkr/src/Enzyme.jl:530
 [11] top-level scope
    @ REPL[13]:1
Some type information was truncated. Use `show(err)` to see complete types.

so I am probably making a mistake. I am using

  [31c24e10] Distributions v0.25.120
  [7da242da] Enzyme v0.13.67
  [57b37032] ImplicitDifferentiation v0.9.0
  [f2b01f46] Roots v2.2.8

and Julia 1.11.6.
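
For what it's worth, the partials can be checked against central finite differences without involving Enzyme at all. The snippet below is a minimal sketch of such a check: it repeats the primal and the partials so that it runs on its own, uses the argument $(1-p) \cdot A + p \cdot B$ from the defining equation, and the step size `h` and tolerance `rtol` are arbitrary choices of mine.

```julia
using Roots, Distributions

# Primal, written with the argument from the defining equation.
solvep(A, B) = find_zero(p -> cdf(Normal(), (1 - p) * A + p * B) - p, (0.0, 1.0))

# Partials from implicit differentiation of the same equation.
function partialsp(A, B, p)
    H′ = pdf(Normal(), (1 - p) * A + p * B)
    C = H′ / (1 - H′ * (B - A))
    (∂p∂A = C * (1 - p), ∂p∂B = C * p)
end

A, B, h = 1.0, 2.0, 1e-6   # h: finite-difference step (arbitrary choice)
p = solvep(A, B)
(; ∂p∂A, ∂p∂B) = partialsp(A, B, p)

# Central finite differences of the primal in each argument.
fdA = (solvep(A + h, B) - solvep(A - h, B)) / (2h)
fdB = (solvep(A, B + h) - solvep(A, B - h)) / (2h)

# Both checks should pass if the analytic partials are right.
@assert isapprox(∂p∂A, fdA; rtol = 1e-5)
@assert isapprox(∂p∂B, fdB; rtol = 1e-5)
```

If both assertions pass, the analytic partials agree with the primal, so any remaining failure would lie in how the rule is registered with Enzyme and not in the math.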
