Skip to content

Commit 3aa6a5a

Browse files
Add internal forward-mode rules for ranges (#1655)
* Add internal forward-mode rules for ranges This is part 1 one solving #274. It does the forward mode rules as those are simpler. A separate PR will do the WIP reverse mode rules as that seems to be a bit more complex. Add missing `@test` don't forget the rule * namespace * Update internal_rules.jl * Update internal_rules.jl * Update src/internal_rules.jl * Update internal_rules.jl * Update internal_rules.jl --------- Co-authored-by: William Moses <[email protected]>
1 parent c0caf9a commit 3aa6a5a

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

src/internal_rules.jl

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,47 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)},
816816
end
817817
end
818818

819+
# Ranges
820+
# Float64 ranges in Julia use bitwise `&` with higher precision
821+
# to correct for numerical error, thus we put rules over the
822+
# operations as this is not directly differentiable
823+
function EnzymeRules.forward(func::Const{Colon}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, start::Annotation, step::Annotation, stop::Annotation)
824+
ret = func.val(start.val, step.val, stop.val)
825+
dstart = if start isa Const
826+
zero(eltype(ret))
827+
elseif start isa Duplicated || start isa DuplicatedNoNeed
828+
one(eltype(ret))
829+
elseif start isa BatchDuplicated || start isa BatchDuplicatedNoNeed
830+
ntuple(x->one(eltype(ret)), Val(width(RT)))
831+
else
832+
error("Annotation type $(typeof(start)) not supported for range start. Please open an issue")
833+
end
834+
835+
dstep = if step isa Const
836+
zero(eltype(ret))
837+
elseif step isa Duplicated || step isa DuplicatedNoNeed
838+
one(eltype(ret))
839+
elseif step isa BatchDuplicated || step isa BatchDuplicatedNoNeed
840+
ntuple(x->one(eltype(ret)), Val(width(RT)))
841+
else
842+
error("Annotation type $(typeof(start)) not supported for range step. Please open an issue")
843+
end
844+
845+
if RT <: Duplicated
846+
Duplicated(ret, range(dstart, step=dstep, length=length(ret)))
847+
elseif RT <: Const
848+
ret
849+
elseif RT <: DuplicatedNoNeed
850+
range(dstart, step=dstep, length=length(ret))
851+
elseif RT <: BatchDuplicated
852+
BatchDuplicated(ret, ntuple(x-> range(dstart, step=dstep, length=length(ret)), Val(width(RT))))
853+
elseif RT <: BatchDuplicatedNoNeed
854+
ntuple(x-> range(dstart, step=dstep, length=length(ret)), Val(width(RT)))
855+
else
856+
error("This should not be possible. Please report.")
857+
end
858+
end
859+
819860
function EnzymeRules.forward(
820861
Ty::Const{Type{BigFloat}},
821862
RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}};
@@ -876,4 +917,4 @@ function EnzymeRules.reverse(
876917
kwargs...,
877918
)
878919
return ()
879-
end
920+
end

test/internal_rules.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,4 +618,27 @@ end
618618
@test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),)
619619
end
620620

621+
@testset "Ranges" begin
622+
function f1(x)
623+
ts = Array(0.0:x:3.0)
624+
sum(ts)
625+
end
626+
function f2(x)
627+
ts = Array(0.0:.25:3.0)
628+
sum(ts) + x
629+
end
630+
function f3(x)
631+
ts = Array(x:.25:3.0)
632+
sum(ts)
633+
end
634+
function f4(x)
635+
ts = Array(0.0:.25:x)
636+
sum(ts)
637+
end
638+
@test Enzyme.autodiff(Forward, f1, Duplicated(0.25, 1.0)) == (78,)
639+
@test Enzyme.autodiff(Forward, f2, Duplicated(0.25, 1.0)) == (1.0,)
640+
@test Enzyme.autodiff(Forward, f3, Duplicated(0.25, 1.0)) == (12,)
641+
@test Enzyme.autodiff(Forward, f4, Duplicated(3.0, 1.0)) == (0,)
642+
end
643+
621644
end # InternalRules

0 commit comments

Comments
 (0)