Skip to content

Commit 389abd9

Browse files
wsmosesavik-palgithub-actions[bot]
authored
Add strong zero (#1351)
* Add strong zero * Update api.md * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix --------- Co-authored-by: Avik Pal <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent e5e4fea commit 389abd9

File tree

4 files changed

+75
-0
lines changed

4 files changed

+75
-0
lines changed

docs/src/api/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,9 @@ Reactant.addressable_devices
5252
```@docs
5353
ReactantCore.materialize_traced_array
5454
```
55+
56+
## Differentiation Specific
57+
58+
```@docs
59+
Reactant.@strongzero
60+
```

src/Enzyme.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,52 @@ const enzyme_dupnoneed = 3
55
const enzyme_outnoneed = 4
66
const enzyme_constnoneed = 5
77

8+
function activate_strongzero!(strongzero::Bool)
9+
stack = get!(task_local_storage(), :reactant_strongzero) do
10+
Bool[]
11+
end
12+
push!(stack, strongzero)
13+
return nothing
14+
end
15+
16+
function deactivate_strongzero!(strongzero::Bool)
17+
key = :reactant_strongzero
18+
strongzero === last(task_local_storage(key)) ||
19+
error("Deactivating wrong strong zerocontext")
20+
return pop!(task_local_storage(key))
21+
end
22+
23+
function get_strongzero()
24+
key = :reactant_strongzero
25+
if !(haskey(task_local_storage(), key) && !Base.isempty(task_local_storage(key)))
26+
return false
27+
end
28+
return last(task_local_storage(key)::Vector{Bool})
29+
end
30+
31+
"""
32+
@strongzero() begin
33+
# Derivative calls that require Enzyme to use string zeroing
34+
end
35+
36+
Whether to enforce multiplication by zero as enforcing a zero result even if multiplying
37+
against a NaN or infinity. Necessary for some programs in which a value has a zero
38+
derivative since it is unused, even if it has an otherwise infinite or nan derivative.
39+
40+
Outside of reactant this is equivalent to setting the global flag Enzyme.API.strong_zero!(true)
41+
before differentiation. This should be moved into the mode in both cases.
42+
"""
43+
macro strongzero(ex)
44+
quote
45+
activate_strongzero!(true)
46+
try
47+
$(esc(ex))
48+
finally
49+
deactivate_strongzero!(true)
50+
end
51+
end
52+
end
53+
854
function Enzyme.make_zero(
955
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
1056
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}
@@ -394,6 +440,7 @@ function overload_autodiff(
394440
outputs=outtys,
395441
fn=fname,
396442
width,
443+
strong_zero=get_strongzero(),
397444
activity=MLIR.IR.Attribute([act_attr(a) for a in activity]),
398445
ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]),
399446
)

src/mlir/Dialects/Enzyme.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ function autodiff(
4444
activity,
4545
ret_activity,
4646
width=nothing,
47+
strong_zero=nothing,
4748
location=Location(),
4849
)
4950
op_ty_results = IR.Type[outputs...,]
@@ -56,6 +57,7 @@ function autodiff(
5657
namedattribute("ret_activity", ret_activity),
5758
]
5859
!isnothing(width) && push!(attributes, namedattribute("width", width))
60+
!isnothing(strong_zero) && push!(attributes, namedattribute("strong_zero", strong_zero))
5961

6062
return create_operation(
6163
"enzyme.autodiff",
@@ -126,6 +128,7 @@ function fwddiff(
126128
activity,
127129
ret_activity,
128130
width=nothing,
131+
strong_zero=nothing,
129132
location=Location(),
130133
)
131134
op_ty_results = IR.Type[outputs...,]
@@ -138,6 +141,7 @@ function fwddiff(
138141
namedattribute("ret_activity", ret_activity),
139142
]
140143
!isnothing(width) && push!(attributes, namedattribute("width", width))
144+
!isnothing(strong_zero) && push!(attributes, namedattribute("strong_zero", strong_zero))
141145

142146
return create_operation(
143147
"enzyme.fwddiff",

test/autodiff.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,21 @@ end
255255
contains(repr(hlo), "stablehlo.rng_bit_generator")
256256
end
257257
end
258+
259+
function divinf(x)
260+
return min(1.0, 1 / x)
261+
end
262+
263+
function grad_divinf(x)
264+
return Enzyme.gradient(Reverse, divinf, x)
265+
end
266+
267+
function grad_divinf_sz(x)
268+
Reactant.@strongzero Enzyme.gradient(Reverse, divinf, x)
269+
end
270+
271+
@testset "Strong zero" begin
272+
x = ConcreteRNumber(0.0)
273+
@test isnan((@jit grad_divinf(x))[1])
274+
@test iszero((@jit grad_divinf_sz(x))[1])
275+
end

0 commit comments

Comments
 (0)