-
Notifications
You must be signed in to change notification settings - Fork 91
Open
Description
I'm trying to define an Enzyme rule for a single branch of a function. However, Enzyme seems to ignore my custom rule and apply its default differentiation instead.
using Enzyme
import .EnzymeRules: forward
# Evaluation points m = 0.0, 0.1, …, 1.0 and a matching tuple of unit tangents
# (one 1.0 per point) used as the `dval` of each `Duplicated` input.
vals = collect(0.0:0.1:1.0)
dvals = ntuple(_ -> 1.0, length(vals))
# Constant branch of `K`: returns π/2 converted to the argument's type,
# independent of the value of `m`. Kept as a separate `@noinline` function,
# presumably so the custom Enzyme `forward` rule can target `typeof(K_m_0)`.
# NOTE(review): because the result does not depend on `m`, Julia's inference may
# infer the call's result as a compile-time constant and fold the call away
# before Enzyme sees it. That would explain why the custom rule is never
# applied — confirm with `@code_typed K(0.0)`.
@noinline K_m_0(m::T) where {T} = T(π / 2)
# Piecewise function under test: exactly at m == 0 it delegates to `K_m_0`
# (the branch the custom rule targets); everywhere else it returns `m` itself.
K(m::T) where {T} = m == 0 ? K_m_0(m) : m
"""
    forward(config, func::Const{typeof(K_m_0)}, RT, m)

Forward-mode Enzyme rule for `K_m_0`.

The primal is `K_m_0(m.val)`. On every shadow lane the tangent is `π/8 * dm`,
converted to the type of `m.val`; it is zero on all lanes when `m` is `Const`.

What is returned depends on `config`:
- primal and shadow -> `Duplicated` (width 1) or `BatchDuplicated` (width > 1)
- shadow only       -> the bare tangent (a scalar, or an `NTuple` for width > 1)
- primal only       -> `K_m_0(m.val)`
- neither           -> `nothing`
"""
function forward(
    config::EnzymeRules.FwdConfig,
    func::Const{typeof(K_m_0)},
    RT::Type{<:Annotation},
    m::Annotation{<:Real},
)
    want_primal = EnzymeRules.needs_primal(config)

    # No derivative requested: return only the primal, or nothing at all.
    if !EnzymeRules.needs_shadow(config)
        return want_primal ? func.val(m.val) : nothing
    end

    T = typeof(m.val)
    N = EnzymeRules.width(config)
    slope = T(π / 8)  # single definition of the tangent slope used on every lane
    shadow = if m isa Const
        # A constant input carries no tangent (and has no `dval` field).
        N == 1 ? zero(m.val) : ntuple(_ -> zero(m.val), Val(N))
    elseif N == 1
        slope * m.dval
    else
        ntuple(i -> slope * m.dval[i], Val(N))
    end

    want_primal || return shadow
    primal = func.val(m.val)
    return N == 1 ? Duplicated(primal, shadow) : BatchDuplicated(primal, shadow)
end
# Run forward-mode `Enzyme.autodiff` on `K` at each (value, tangent) pair and
# collect the per-point results into a vector.
function test_forward(vals, dvals)
    return map(zip(vals, dvals)) do (x, dx)
        Enzyme.autodiff(Forward, (m,) -> K(m), Duplicated, Duplicated(x, dx))
    end
end

test_forward(vals, dvals)
Actual output:
julia> test_forward(vals, dvals)
11-element Vector{Tuple{Float64}}:
(0.0,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
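(1.0,)

Expected output (note: with the rule above, the tangent at m = 0 is π/8 · dval = 0.39269908169872414; the 1.5707963267948966 shown below is π/2, the primal value returned by `K_m_0`):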
julia> test_forward(vals, dvals)
11-element Vector{Tuple{Float64}}:
(1.5707963267948966,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
(1.0,)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels