Skip to content

Commit 56c07f6

Browse files
authored
fix: ignore_derivatives need to preserve structure (#1379)
1 parent f2ad3ef commit 56c07f6

File tree

3 files changed

+27
-8
lines changed

3 files changed

+27
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.128"
4+
version = "0.2.129"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/Enzyme.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -549,12 +549,15 @@ be applied on a nested structure of arrays and we will apply the operation on ea
549549
leaves.
550550
"""
551551
function ignore_derivatives(args...)
552-
return map(args) do arg
553-
return Functors.fmap(arg) do argᵢ
554-
if argᵢ isa TracedType || argᵢ isa AnyTracedRArray
555-
return Ops.ignore_derivatives(materialize_traced_array(argᵢ))
556-
end
557-
return argᵢ
558-
end
552+
res = map(ignore_derivatives_internal, args)
553+
length(args) == 1 && return only(res)
554+
return res
555+
end
556+
557+
function ignore_derivatives_internal(arg)
558+
return Functors.fmap(arg) do argᵢ
559+
argᵢ isa AnyTracedRArray && (argᵢ = materialize_traced_array(argᵢ))
560+
argᵢ isa TracedType && return Ops.ignore_derivatives(argᵢ)
561+
return argᵢ
559562
end
560563
end

test/autodiff.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,14 @@ function simple_grad_with_ignore(x::AbstractArray{T}) where {T}
282282
return Reactant.ignore_derivatives(sum(x; dims=1), x .- 1, (x, x .+ 2)), sum(abs2, x)
283283
end
284284

285+
function zero_grad(x)
286+
return Reactant.ignore_derivatives(sum(x))
287+
end
288+
289+
function zero_grad2(x)
290+
return Reactant.ignore_derivatives(sum(x), x)
291+
end
292+
285293
@testset "ignore_derivatives" begin
286294
x = Reactant.to_rarray(rand(Float32, 4, 4))
287295

@@ -290,4 +298,12 @@ end
290298

291299
res2 = @jit Enzyme.gradient(Reverse, simple_grad_with_ignore, x)
292300
@test res2[1] (2 .* Array(x))
301+
302+
∂x, result = @jit Enzyme.gradient(ReverseWithPrimal, zero_grad, x)
303+
@test result isa ConcreteRNumber{Float32}
304+
@test ∂x[1] zeros(Float32, 4, 4)
305+
306+
∂x2, result2 = @jit Enzyme.gradient(ReverseWithPrimal, zero_grad2, x)
307+
@test result2 isa Tuple{<:ConcreteRNumber{Float32},<:ConcreteRArray{Float32,2}}
308+
@test ∂x2[1] zeros(Float32, 4, 4)
293309
end

0 commit comments

Comments
 (0)