Skip to content

Commit 7a24937

Browse files
authored
Add non scalar return error (#2404)
* Add non scalar return error * fix * Update errors.jl
1 parent 2c2fb62 commit 7a24937

File tree

6 files changed

+29
-13
lines changed

6 files changed

+29
-13
lines changed

src/compiler.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4986,10 +4986,7 @@ function add_one_in_place(x)
49864986
elseif x isa (Array{T,0} where T)
49874987
x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x))))
49884988
else
4989-
error(
4990-
"Enzyme Mutability Error: Cannot add one in place to immutable value " *
4991-
string(x),
4992-
)
4989+
throw(EnzymeNonScalarReturnException(x, ""))
49934990
end
49944991
return nothing
49954992
end

src/errors.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,3 +826,20 @@ end
826826
end
827827
throw(AssertionError("Unknown errtype"))
828828
end
829+
830+
struct EnzymeNonScalarReturnException <: EnzymeError
831+
object
832+
extra::String
833+
end
834+
835+
function Base.showerror(io::IO, ece::EnzymeNonScalarReturnException)
836+
if isdefined(Base.Experimental, :show_error_hints)
837+
Base.Experimental.show_error_hints(io, ece)
838+
end
839+
println(io, "Return type of differentiated function was not a scalar as required, found ", ece.object)
840+
println(io, "If calling Enzyme.autodiff(Reverse, f, Active, ...), try Enzyme.autodiff_thunk(Reverse, f, Duplicated, ....)")
841+
println(io, "If calling Enzyme.gradient, try Enzyme.jacobian")
842+
if length(ece.extra) != 0
843+
print(io, ece.extra)
844+
end
845+
end

src/rules/jitrules.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -606,18 +606,15 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
606606
elseif $shad isa Base.RefValue
607607
$shad[] = recursive_add($shad[], $expr)
608608
else
609-
error(
610-
"Enzyme Mutability Error: Cannot add one in place to immutable value " *
611-
string($shad) *
612-
" tup[i]=" *
609+
throw(EnzymeNonScalarReturnException($shad, " tup[i]=" *
613610
string(tup[$i]) *
614611
" i=" *
615612
string($i) *
616613
" w=" *
617614
string($w) *
618615
" tup=" *
619616
string(tup),
620-
)
617+
))
621618
end
622619
end
623620
@inbounds outs[(i-1)*Width+w] = out

src/rules/typeunstablerules.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,7 @@ function body_construct_rev(
124124
if $shad isa Base.RefValue
125125
$shad[] = recursive_add($shad[], $expr, identity, guaranteed_nonactive)
126126
else
127-
error(
128-
"Enzyme Mutability Error: Cannot add one in place to immutable value " *
129-
string($shad),
130-
)
127+
throw(EnzymeNonScalarReturnException($shad, ""))
131128
end
132129
end
133130
)

test/errors.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
using Enzyme, Test
2+
3+
array_square(x) = 2 .* x
4+
5+
@testset "Array of Pointer Copy" begin
6+
@test_throws Enzyme.Compiler.EnzymeNonScalarReturnException Enzyme.gradient(Reverse, array_square, [2.0])
7+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2968,6 +2968,7 @@ end
29682968
end
29692969

29702970
include("sugar.jl")
2971+
include("errors.jl")
29712972

29722973
@testset "Forward on Reverse" begin
29732974

0 commit comments

Comments
 (0)