Skip to content

Commit 707cba8

Browse files
Strong zero fixup (#1359)
* Strong zero fixup * Update test/autodiff.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 925e789 commit 707cba8

File tree

4 files changed

+9
-65
lines changed

4 files changed

+9
-65
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ CEnum = "0.5"
6969
CUDA = "5.6"
7070
Downloads = "1.6"
7171
EnumX = "1"
72-
Enzyme = "0.13.46"
73-
EnzymeCore = "0.8.9"
72+
Enzyme = "0.13.47"
73+
EnzymeCore = "0.8.11"
7474
Functors = "0.5"
7575
GPUArraysCore = "0.2"
7676
GPUCompiler = "1.3"

docs/src/api/api.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,3 @@ Reactant.addressable_devices
5252
```@docs
5353
ReactantCore.materialize_traced_array
5454
```
55-
56-
## Differentiation Specific
57-
58-
```@docs
59-
Reactant.@strongzero
60-
```

src/Enzyme.jl

Lines changed: 6 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -5,52 +5,6 @@ const enzyme_dupnoneed = 3
55
const enzyme_outnoneed = 4
66
const enzyme_constnoneed = 5
77

8-
function activate_strongzero!(strongzero::Bool)
9-
stack = get!(task_local_storage(), :reactant_strongzero) do
10-
Bool[]
11-
end
12-
push!(stack, strongzero)
13-
return nothing
14-
end
15-
16-
function deactivate_strongzero!(strongzero::Bool)
17-
key = :reactant_strongzero
18-
strongzero === last(task_local_storage(key)) ||
19-
error("Deactivating wrong strong zerocontext")
20-
return pop!(task_local_storage(key))
21-
end
22-
23-
function get_strongzero()
24-
key = :reactant_strongzero
25-
if !(haskey(task_local_storage(), key) && !Base.isempty(task_local_storage(key)))
26-
return false
27-
end
28-
return last(task_local_storage(key)::Vector{Bool})
29-
end
30-
31-
"""
32-
@strongzero() begin
33-
# Derivative calls that require Enzyme to use string zeroing
34-
end
35-
36-
Whether to enforce multiplication by zero as enforcing a zero result even if multiplying
37-
against a NaN or infinity. Necessary for some programs in which a value has a zero
38-
derivative since it is unused, even if it has an otherwise infinite or nan derivative.
39-
40-
Outside of reactant this is equivalent to setting the global flag Enzyme.API.strong_zero!(true)
41-
before differentiation. This should be moved into the mode in both cases.
42-
"""
43-
macro strongzero(ex)
44-
quote
45-
activate_strongzero!(true)
46-
try
47-
$(esc(ex))
48-
finally
49-
deactivate_strongzero!(true)
50-
end
51-
end
52-
end
53-
548
function Enzyme.make_zero(
559
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
5610
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}
@@ -338,13 +292,9 @@ function overload_autodiff(
338292
end
339293

340294
outtys = MLIR.IR.Type[]
341-
@inline needs_primal(::Type{<:Enzyme.ReverseMode{ReturnPrimal}}) where {ReturnPrimal} =
342-
ReturnPrimal
343-
@inline needs_primal(::Type{<:Enzyme.ForwardMode{ReturnPrimal}}) where {ReturnPrimal} =
344-
ReturnPrimal
345295
for a in linear_results
346296
if TracedUtils.has_idx(a, resprefix)
347-
if needs_primal(CMode)
297+
if Enzyme.needs_primal(CMode)
348298
push!(
349299
outtys,
350300
TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))),
@@ -389,7 +339,7 @@ function overload_autodiff(
389339
ret_activity = Int32[]
390340
for a in linear_results
391341
if TracedUtils.has_idx(a, resprefix)
392-
act = act_from_type(A, reverse, needs_primal(CMode))
342+
act = act_from_type(A, reverse, Enzyme.needs_primal(CMode))
393343
push!(ret_activity, act)
394344
if act == enzyme_out || act == enzyme_outnoneed
395345
attr = MLIR.IR.DenseElementsAttribute(
@@ -440,7 +390,7 @@ function overload_autodiff(
440390
outputs=outtys,
441391
fn=fname,
442392
width,
443-
strong_zero=get_strongzero(),
393+
strong_zero=Enzyme.strong_zero(CMode),
444394
activity=MLIR.IR.Attribute([act_attr(a) for a in activity]),
445395
ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]),
446396
)
@@ -462,7 +412,7 @@ function overload_autodiff(
462412

463413
for a in linear_results
464414
if TracedUtils.has_idx(a, resprefix)
465-
if needs_primal(CMode)
415+
if Enzyme.needs_primal(CMode)
466416
path = TracedUtils.get_idx(a, resprefix)
467417
tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx))
468418
TracedUtils.set!(result, path[2:end], tval)
@@ -567,14 +517,14 @@ function overload_autodiff(
567517
func2.operation = MLIR.API.MlirOperation(C_NULL)
568518

569519
if reverse
570-
resv = if needs_primal(CMode)
520+
resv = if Enzyme.needs_primal(CMode)
571521
result
572522
else
573523
nothing
574524
end
575525
return ((restup...,), resv)
576526
else
577-
if needs_primal(CMode)
527+
if Enzyme.needs_primal(CMode)
578528
if CMode <: Enzyme.ForwardMode && !(A <: Enzyme.Const)
579529
(dresult, result)
580530
else

test/autodiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ function grad_divinf(x)
265265
end
266266

267267
function grad_divinf_sz(x)
268-
Reactant.@strongzero Enzyme.gradient(Reverse, divinf, x)
268+
return Enzyme.gradient(Enzyme.set_strong_zero(Reverse), divinf, x)
269269
end
270270

271271
@testset "Strong zero" begin

0 commit comments

Comments
 (0)