Skip to content

Commit 73abcf1

Browse files
Pushforward dvals in the ranges forward rule (#1663)
* Pushforward dvals in the ranges forward rule I just realized before it merged that we're actually missing a part of the pushforward here. Even though the solution to the derivative is just the range itself, we forgot to multiply it by the dval to pushforward the derivative. It implicitly had it as one, calculating the derivative, instead of the full pushforward. The test should be updated to catch this. * Test dual propagation Changes the inputs to be the same but have another operation * fix and test batch rules and format
1 parent 7744a8c commit 73abcf1

File tree

2 files changed

+47
-26
lines changed

2 files changed

+47
-26
lines changed

src/internal_rules.jl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -820,38 +820,46 @@ end
820820
# Float64 ranges in Julia use bitwise `&` with higher precision
821821
# to correct for numerical error, thus we put rules over the
822822
# operations as this is not directly differentiable
823-
function EnzymeRules.forward(func::Const{Colon}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, start::Annotation, step::Annotation, stop::Annotation)
823+
function EnzymeRules.forward(func::Const{Colon},
824+
RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated,
825+
BatchDuplicated,BatchDuplicatedNoNeed}},
826+
start::Annotation, step::Annotation, stop::Annotation)
824827
ret = func.val(start.val, step.val, stop.val)
825-
dstart = if start isa Const
826-
zero(eltype(ret))
828+
dstart = if start isa Const
829+
zero(eltype(ret))
827830
elseif start isa Duplicated || start isa DuplicatedNoNeed
828-
one(eltype(ret))
831+
start.dval
829832
elseif start isa BatchDuplicated || start isa BatchDuplicatedNoNeed
830-
ntuple(x->one(eltype(ret)), Val(width(RT)))
833+
ntuple(i -> start.dval[i], Val(width(RT)))
831834
else
832835
error("Annotation type $(typeof(start)) not supported for range start. Please open an issue")
833836
end
834837

835-
dstep = if step isa Const
836-
zero(eltype(ret))
838+
dstep = if step isa Const
839+
zero(eltype(ret))
837840
elseif step isa Duplicated || step isa DuplicatedNoNeed
838-
one(eltype(ret))
841+
step.dval
839842
elseif step isa BatchDuplicated || step isa BatchDuplicatedNoNeed
840-
ntuple(x->one(eltype(ret)), Val(width(RT)))
843+
ntuple(i -> step.dval[i], Val(width(RT)))
841844
else
842845
error("Annotation type $(typeof(start)) not supported for range step. Please open an issue")
843846
end
844847

845-
if RT <: Duplicated
846-
Duplicated(ret, range(dstart, step=dstep, length=length(ret)))
848+
if RT <: Duplicated
849+
Duplicated(ret, range(dstart; step=dstep, length=length(ret)))
847850
elseif RT <: Const
848851
ret
849852
elseif RT <: DuplicatedNoNeed
850-
range(dstart, step=dstep, length=length(ret))
853+
range(dstart; step=dstep, length=length(ret))
851854
elseif RT <: BatchDuplicated
852-
BatchDuplicated(ret, ntuple(x-> range(dstart, step=dstep, length=length(ret)), Val(width(RT))))
855+
BatchDuplicated(ret,
856+
ntuple(i -> range(dstart isa Number ? dstart : dstart[i];
857+
step=dstep isa Number ? dstep : dstep[i],
858+
length=length(ret)), Val(width(RT))))
853859
elseif RT <: BatchDuplicatedNoNeed
854-
ntuple(x-> range(dstart, step=dstep, length=length(ret)), Val(width(RT)))
860+
ntuple(i -> range(dstart isa Number ? dstart : dstart[i];
861+
step=dstep isa Number ? dstep : dstep[i],
862+
length=length(ret)), Val(width(RT)))
855863
else
856864
error("This should not be possible. Please report.")
857865
end

test/internal_rules.jl

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -620,25 +620,38 @@ end
620620

621621
@testset "Ranges" begin
622622
function f1(x)
623+
x = 25.0x
623624
ts = Array(0.0:x:3.0)
624-
sum(ts)
625+
return sum(ts)
625626
end
626627
function f2(x)
627-
ts = Array(0.0:.25:3.0)
628-
sum(ts) + x
628+
x = 25.0x
629+
ts = Array(0.0:0.25:3.0)
630+
return sum(ts) + x
629631
end
630632
function f3(x)
631-
ts = Array(x:.25:3.0)
632-
sum(ts)
633+
x = 25.0x
634+
ts = Array(x:0.25:3.0)
635+
return sum(ts)
633636
end
634637
function f4(x)
635-
ts = Array(0.0:.25:x)
636-
sum(ts)
637-
end
638-
@test Enzyme.autodiff(Forward, f1, Duplicated(0.25, 1.0)) == (78,)
639-
@test Enzyme.autodiff(Forward, f2, Duplicated(0.25, 1.0)) == (1.0,)
640-
@test Enzyme.autodiff(Forward, f3, Duplicated(0.25, 1.0)) == (12,)
641-
@test Enzyme.autodiff(Forward, f4, Duplicated(3.0, 1.0)) == (0,)
638+
x = 25.0x
639+
ts = Array(0.0:0.25:x)
640+
return sum(ts)
641+
end
642+
@test Enzyme.autodiff(Forward, f1, Duplicated(0.1, 1.0)) == (25.0,)
643+
@test Enzyme.autodiff(Forward, f2, Duplicated(0.1, 1.0)) == (25.0,)
644+
@test Enzyme.autodiff(Forward, f3, Duplicated(0.1, 1.0)) == (75.0,)
645+
@test Enzyme.autodiff(Forward, f4, Duplicated(0.12, 1.0)) == (0,)
646+
647+
@test Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0))) ==
648+
((var"1"=25.0, var"2"=50.0),)
649+
@test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) ==
650+
((var"1"=25.0, var"2"=50.0),)
651+
@test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) ==
652+
((var"1"=75.0, var"2"=150.0),)
653+
@test Enzyme.autodiff(Forward, f4, BatchDuplicated(0.12, (1.0, 2.0))) ==
654+
((var"1"=0.0, var"2"=0.0),)
642655
end
643656

644657
end # InternalRules

0 commit comments

Comments
 (0)