Skip to content

Commit df9bff9

Browse files
authored
Make zero in place (#1518)
* Make zero in place * add make_zero! * more fixes and tests
1 parent ad7694e commit df9bff9

File tree

8 files changed

+239
-12
lines changed

8 files changed

+239
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1919
[compat]
2020
CEnum = "0.4, 0.5"
2121
ChainRulesCore = "1"
22-
EnzymeCore = "0.7.3"
22+
EnzymeCore = "0.7.4"
2323
Enzyme_jll = "0.0.119"
2424
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26"
2525
LLVM = "6.1, 7"

examples/custom_rule.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ function forward(func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNee
134134
if !(x isa Const) && !(y isa Const)
135135
y.dval .= 2 .* x.val .* x.dval
136136
elseif !(y isa Const)
137-
y.dval .= 0
137+
make_zero!(y.dval)
138138
end
139139
dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val))
140140
if RT <: Const
@@ -211,7 +211,7 @@ function reverse(config::ConfigWidth{1}, func::Const{typeof(f)}, dret::Active, t
211211
x.dval .+= 2 .* xval .* dret.val
212212
## also accumulate any derivative in y's shadow into x's shadow.
213213
x.dval .+= 2 .* xval .* y.dval
214-
y.dval .= 0
214+
make_zero!(y.dval)
215215
return (nothing, nothing)
216216
end
217217

@@ -251,8 +251,8 @@ end
251251

252252
x = [3.0, 1.0]
253253
y = [0.0, 0.0]
254-
dx .= 0
255-
dy .= 0
254+
make_zero!(dx)
255+
make_zero!(dy)
256256

257257
autodiff(Reverse, h, Duplicated(y, dy), Duplicated(x, dx))
258258
@show dx # derivative of h w.r.t. x

lib/EnzymeCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EnzymeCore"
22
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
4-
version = "0.7.3"
4+
version = "0.7.4"
55

66
[compat]
77
Adapt = "3, 4"

lib/EnzymeCore/src/EnzymeCore.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,13 @@ function autodiff_deferred_thunk end
228228
"""
229229
function make_zero end
230230

231+
"""
232+
make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing
233+
234+
Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`.
235+
"""
236+
function make_zero! end
237+
231238
"""
232239
make_zero(prev::T)
233240

src/Enzyme.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ export BatchDuplicatedFunc
1414
import EnzymeCore: batch_size, get_func
1515
export batch_size, get_func
1616

17-
import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero
18-
export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero
17+
import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero!
18+
export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero!
1919

2020
export jacobian, gradient, gradient!
2121
export markType, batch_size, onehot, chunkedonehot
@@ -1007,7 +1007,7 @@ gradient!(Reverse, dx, f, [2.0, 3.0])
10071007
```
10081008
"""
10091009
@inline function gradient!(::ReverseMode, dx::X, f::F, x::X) where {X<:Array, F}
1010-
dx .= 0
1010+
make_zero!(dx)
10111011
autodiff(Reverse, f, Active, Duplicated(x, dx))
10121012
dx
10131013
end

src/api.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ struct CFnTypeInfo
104104
end
105105

106106

107-
@static if isdefined(LLVM, :InstructionMetadataDict)
107+
@static if !isdefined(LLVM, :ValueMetadataDict)
108108
Base.haskey(md::LLVM.InstructionMetadataDict, kind::String) =
109109
ccall((:EnzymeGetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring), md.inst, kind) != C_NULL
110110

src/compiler.jl

Lines changed: 200 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,7 +1298,7 @@ end
12981298
xi = getfield(prev, i)
12991299
T = Core.Typeof(xi)
13001300
xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive))
1301-
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi)
1301+
setfield!(y, i, xi)
13021302
end
13031303
end
13041304
return y
@@ -1324,6 +1324,204 @@ end
13241324
return y
13251325
end
13261326

1327+
function make_zero_immutable!(prev::T, seen::S)::T where {T <: AbstractFloat, S}
1328+
zero(T)
1329+
end
1330+
1331+
function make_zero_immutable!(prev::Complex{T}, seen::S)::Complex{T} where {T <: AbstractFloat, S}
1332+
zero(T)
1333+
end
1334+
1335+
function make_zero_immutable!(prev::T, seen::S)::T where {T <: Tuple, S}
1336+
ntuple(Val(length(T.parameters))) do i
1337+
Base.@_inline_meta
1338+
make_zero_immutable!(prev[i], seen)
1339+
end
1340+
end
1341+
1342+
function make_zero_immutable!(prev::NamedTuple{a, b}, seen::S)::NamedTuple{a, b} where {a,b, S}
1343+
NamedTuple{a, b}(
1344+
ntuple(Val(length(T.parameters))) do i
1345+
Base.@_inline_meta
1346+
make_zero_immutable!(prev[a[i]], seen)
1347+
end
1348+
)
1349+
end
1350+
1351+
1352+
function make_zero_immutable!(prev::T, seen::S)::T where {T, S}
1353+
if guaranteed_const_nongen(T, nothing)
1354+
return prev
1355+
end
1356+
@assert !ismutable(T)
1357+
1358+
@assert !Base.isabstracttype(RT)
1359+
@assert Base.isconcretetype(RT)
1360+
nf = fieldcount(RT)
1361+
1362+
flds = Vector{Any}(undef, nf)
1363+
for i in 1:nf
1364+
if isdefined(prev, i)
1365+
xi = getfield(prev, i)
1366+
ST = Core.Typeof(xi)
1367+
flds[i] = if active_reg_inner(ST, (), nothing, #=justActive=#Val(true)) == ActiveState
1368+
make_zero_immutable!(xi, seen)
1369+
else
1370+
EnzymeCore.make_zero!(xi, seen)
1371+
xi
1372+
end
1373+
else
1374+
nf = i - 1 # rest of tail must be undefined values
1375+
break
1376+
end
1377+
end
1378+
ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T
1379+
end
1380+
1381+
@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T <: AbstractFloat, ST}
1382+
T[] = zero(T)
1383+
nothing
1384+
end
1385+
1386+
@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}}, seen::ST)::Nothing where {T <: AbstractFloat, ST}
1387+
T[] = zero(Complex{T})
1388+
nothing
1389+
end
1390+
1391+
@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST}
1392+
fill!(prev, zero(T))
1393+
nothing
1394+
end
1395+
1396+
@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST}
1397+
fill!(prev, zero(Complex{T}))
1398+
nothing
1399+
end
1400+
1401+
@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T})::Nothing where {T <: AbstractFloat}
1402+
EnzymeCore.make_zero!(prev, nothing)
1403+
nothing
1404+
end
1405+
1406+
@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}})::Nothing where {T <: AbstractFloat}
1407+
EnzymeCore.make_zero!(prev, nothing)
1408+
nothing
1409+
end
1410+
1411+
@inline function EnzymeCore.make_zero!(prev::Array{T, N})::Nothing where {T <: AbstractFloat, N}
1412+
EnzymeCore.make_zero!(prev, nothing)
1413+
nothing
1414+
end
1415+
1416+
@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N})::Nothing where {T <: AbstractFloat, N}
1417+
EnzymeCore.make_zero!(prev, nothing)
1418+
nothing
1419+
end
1420+
1421+
@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T, N, ST}
1422+
if guaranteed_const_nongen(T, nothing)
1423+
return
1424+
end
1425+
if in(seen, prev)
1426+
return
1427+
end
1428+
push!(seen, prev)
1429+
1430+
for I in eachindex(prev)
1431+
if isassigned(prev, I)
1432+
pv = prev[I]
1433+
SBT = Core.Typeof(pv)
1434+
if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState
1435+
@inbounds prev[I] = make_zero_immutable!(pv, seen)
1436+
nothing
1437+
else
1438+
EnzymeCore.make_zero!(pv, seen)
1439+
nothing
1440+
end
1441+
end
1442+
end
1443+
nothing
1444+
end
1445+
1446+
@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T, ST}
1447+
if guaranteed_const_nongen(T, nothing)
1448+
return
1449+
end
1450+
if in(seen, prev)
1451+
return
1452+
end
1453+
push!(seen, prev)
1454+
1455+
pv = prev[]
1456+
SBT = Core.Typeof(pv)
1457+
if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState
1458+
prev[] = make_zero_immutable!(pv, seen)
1459+
nothing
1460+
else
1461+
EnzymeCore.make_zero!(pv, seen)
1462+
nothing
1463+
end
1464+
nothing
1465+
end
1466+
1467+
@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST}
1468+
pv = prev.contents
1469+
T = Core.Typeof(pv)
1470+
if guaranteed_const_nongen(T, nothing)
1471+
return
1472+
end
1473+
if in(seen, prev)
1474+
return
1475+
end
1476+
push!(seen, prev)
1477+
SBT = Core.Typeof(pv)
1478+
if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState
1479+
prev.contents = EnzymeCore.make_zero_immutable!(pv, seen)
1480+
nothing
1481+
else
1482+
EnzymeCore.make_zero!(pv, seen)
1483+
nothing
1484+
end
1485+
nothing
1486+
end
1487+
1488+
@inline function EnzymeCore.make_zero!(prev::T, seen::S=Base.IdSet{Any}())::Nothing where {T, S}
1489+
if guaranteed_const_nongen(T, nothing)
1490+
return
1491+
end
1492+
if in(seen, prev)
1493+
return
1494+
end
1495+
@assert !Base.isabstracttype(T)
1496+
@assert Base.isconcretetype(T)
1497+
nf = fieldcount(T)
1498+
1499+
1500+
if nf == 0
1501+
return
1502+
end
1503+
1504+
push!(seen, prev)
1505+
1506+
for i in 1:nf
1507+
if isdefined(prev, i)
1508+
xi = getfield(prev, i)
1509+
SBT = Core.Typeof(xi)
1510+
if guaranteed_const_nongen(SBT, nothing)
1511+
continue
1512+
end
1513+
if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState
1514+
setfield!(prev, i, make_zero_immutable!(xi, seen))
1515+
nothing
1516+
else
1517+
EnzymeCore.make_zero!(xi, seen)
1518+
nothing
1519+
end
1520+
end
1521+
end
1522+
return
1523+
end
1524+
13271525
struct EnzymeRuntimeException <: Base.Exception
13281526
msg::Cstring
13291527
end
@@ -5536,7 +5734,7 @@ end
55365734
@assert ismutable(x)
55375735
yi = getfield(y, i)
55385736
nexti = recursive_add(xi, yi, f, mutable_register)
5539-
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x, i-1, nexti)
5737+
setfield!(x, i, nexti)
55405738
end
55415739
end
55425740
end

test/runtests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,28 @@ end
181181
# @test thunk_split.primal !== C_NULL
182182
# @test thunk_split.primal !== thunk_split.adjoint
183183
# @test thunk_a.adjoint !== thunk_split.adjoint
184+
#
185+
z = ([3.14, 21.5, 16.7], [0,1], [5.6, 8.9])
186+
Enzyme.make_zero!(z)
187+
@test z[1] [0.0, 0.0, 0.0]
188+
@test z[2][1] == 0
189+
@test z[2][2] == 1
190+
@test z[3] [0.0, 0.0]
191+
192+
z2 = ([3.14, 21.5, 16.7], [0,1], [5.6, 8.9])
193+
Enzyme.make_zero!(z2)
194+
@test z2[1] [0.0, 0.0, 0.0]
195+
@test z2[2][1] == 0
196+
@test z2[2][2] == 1
197+
@test z2[3] [0.0, 0.0]
198+
199+
z3 = [3.4, "foo"]
200+
Enzyme.make_zero!(z3)
201+
@test z3[1] 0.0
202+
@test z3[2] == "foo"
203+
204+
z4 = sin
205+
Enzyme.make_zero!(z4)
184206
end
185207

186208
@testset "Reflection" begin

0 commit comments

Comments
 (0)