Skip to content

Commit 181748f

Browse files
m-bossartChrisRackauckas
authored andcommitted
wip: rules are hit
1 parent 7d3220c commit 181748f

File tree

1 file changed

+26
-18
lines changed

1 file changed

+26
-18
lines changed

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,22 +53,31 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}
5353
return ntuple(_ -> nothing, Val(length(args) + 4))
5454
end
5555

56-
function EnzymeRules.forward(func::Const{typeof(fastpow)},
57-
RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated,
58-
BatchDuplicated,BatchDuplicatedNoNeed}},
59-
_x::Annotation, _y::Annotation)
56+
function Enzyme.EnzymeRules.forward(func::Const{typeof(DiffEqBase.fastpow)},
57+
RT::Type{<:Union{Duplicated, DuplicatedNoNeed}},
58+
_x::Union{Const, Duplicated}, _y::Union{Const, Duplicated})
6059
x = _x.val
6160
y = _y.val
62-
ret = func.val(x.val, y.val)
63-
dxval = x.dval * y * (fastpow(x,y - 1))
64-
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x)
65-
return Duplicated(ret, dxval + dyval)
61+
ret = func.val(x, y)
62+
if !(_x isa Const)
63+
dxval = _x.dval * y * (fastpow(x,y - 1))
64+
else
65+
dxval = make_zero(_x.val)
66+
end
67+
if !(_y isa Const)
68+
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : _y.dval*(fastpow(x,y))*log(x)
69+
else
70+
dyval = make_zero(_y.val)
71+
end
72+
if RT <: DuplicatedNoNeed
73+
return Float32(dxval + dyval)
74+
else
75+
return Duplicated(ret, Float32(dxval + dyval))
76+
end
6677
end
6778

68-
function EnzymeRules.augmented_primal(config::ConfigWidth{1},
69-
func::Const{typeof(fastpow)},
70-
::Type{<:Active},
71-
x::Active, x::Active)
79+
function EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWidth{1},
80+
func::Const{typeof(fastpow)}, ::Type{<:Active}, x::Active, y::Active)
7281
if EnzymeRules.needs_primal(config)
7382
primal = func.val(x.val, y.val)
7483
else
@@ -77,14 +86,13 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1},
7786
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
7887
end
7988

80-
function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1},
81-
func::Const{DiffEqBase.fastpow}, dret, tape::Nothing,
82-
_x, _y)
89+
function EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1},
90+
func::Const{typeof(DiffEqBase.fastpow)}, dret::Active, tape, _x::Active, _y::Active)
8391
x = _x.val
8492
y = _y.val
85-
dxval = x.dval * y * (fastpow(x,y - 1))
86-
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x)
93+
dxval = y * (fastpow(x,y - 1))
94+
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : (fastpow(x,y))*log(x)
8795
return (dxval, dyval)
8896
end
8997

90-
end
98+
end

0 commit comments

Comments
 (0)