Skip to content

Commit 3563bbd

Browse files
authored
Add missing colon rule (#2641)
1 parent aaa5a36 commit 3563bbd

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

src/internal_rules.jl

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1380,7 +1380,7 @@ function EnzymeRules.augmented_primal(
13801380
start::Annotation{<:AbstractFloat},
13811381
step::Annotation{<:AbstractFloat},
13821382
stop::Annotation{<:AbstractFloat},
1383-
) where RT <: Active
1383+
) where RT <: Union{Active,Const}
13841384

13851385
if EnzymeRules.needs_primal(config)
13861386
primal = func.val(start.val, step.val, stop.val)
@@ -1394,6 +1394,52 @@ function EnzymeRules.augmented_primal(
13941394
}(primal, nothing, nothing)
13951395
end
13961396

1397+
function EnzymeRules.reverse(
1398+
config::EnzymeRules.RevConfig,
1399+
func::Const{Colon},
1400+
dret::Const,
1401+
tape::Nothing,
1402+
start::Annotation{T1},
1403+
step::Annotation{T2},
1404+
stop::Annotation{T3},
1405+
) where {T1<:AbstractFloat,T2<:AbstractFloat,T3<:AbstractFloat}
1406+
dstart = if start isa Const
1407+
nothing
1408+
elseif EnzymeRules.width(config) == 1
1409+
zero(T1)
1410+
else
1411+
ntuple(Val(EnzymeRules.width(config))) do i
1412+
Base.@_inline_meta
1413+
zero(T1)
1414+
end
1415+
end
1416+
1417+
dstep = if step isa Const
1418+
nothing
1419+
elseif EnzymeRules.width(config) == 1
1420+
zero(T2)
1421+
else
1422+
ntuple(Val(EnzymeRules.width(config))) do i
1423+
Base.@_inline_meta
1424+
zero(T2)
1425+
end
1426+
end
1427+
1428+
dstop = if stop isa Const
1429+
nothing
1430+
elseif EnzymeRules.width(config) == 1
1431+
zero(T3)
1432+
else
1433+
ntuple(Val(EnzymeRules.width(config))) do i
1434+
Base.@_inline_meta
1435+
zero(T3)
1436+
end
1437+
end
1438+
1439+
return (dstart, dstep, dstop)
1440+
end
1441+
1442+
13971443
function EnzymeRules.reverse(
13981444
config::EnzymeRules.RevConfig,
13991445
func::Const{Colon},

0 commit comments

Comments
 (0)