Skip to content

Commit 8a2fa2e

Browse files
authored
add support for ignore_derivatives (#2547)
1 parent 297c7c7 commit 8a2fa2e

File tree

8 files changed

+83
-50
lines changed

8 files changed

+83
-50
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Enzyme"
22
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
4-
version = "0.13.80"
4+
version = "0.13.81"
55

66
[deps]
77
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
@@ -44,7 +44,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5"
4444
CEnum = "0.4, 0.5"
4545
ChainRulesCore = "1"
4646
DynamicPPL = "0.35, 0.36, 0.37"
47-
EnzymeCore = "0.8.13"
47+
EnzymeCore = "0.8.14"
4848
Enzyme_jll = "0.0.201"
4949
GPUArraysCore = "0.1.6, 0.2"
5050
GPUCompiler = "1.6.2"

docs/src/notebooks/ignore_derivatives.jl

Lines changed: 4 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ begin
3131
)
3232
end
3333

34+
# ╔═╡ 23a8503f-3c68-4523-aebe-a4ce4575a02b
35+
import Enzyme: ignore_derivatives
36+
3437
# ╔═╡ df72e42f-7eec-476f-8ce5-72b09f620005
3538
md"""
3639
# Reproducing "Stabilizing backpropagation through time to learn complex physics"
@@ -116,51 +119,6 @@ end
116119
# ╔═╡ be852753-126d-42fa-a55c-c907f5dce99d
117120
plot_gradientfield(N, S, x₀, y, θ₁, θ₂, n)
118121

119-
# ╔═╡ 0b6d0456-1f94-479f-b690-89ad7fc61e44
120-
begin
121-
@noinline function ignore_derivatives(x::T) where {T}
122-
return Core.inferencebarrier(x)::T
123-
end
124-
125-
function EnzymeRules.forward(
126-
config,
127-
::Const{typeof(ignore_derivatives)},
128-
A, x::Duplicated
129-
)
130-
return Enzyme.make_zero(x.val)
131-
end
132-
133-
function EnzymeRules.augmented_primal(
134-
config,
135-
::Const{typeof(ignore_derivatives)},
136-
FA, x
137-
)
138-
primal = EnzymeRules.needs_primal(config) ? x.val : nothing
139-
if x isa Active
140-
shadow = nothing
141-
else
142-
shadow = Enzyme.make_zero(x.val)
143-
end
144-
145-
return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
146-
end
147-
function EnzymeRules.reverse(
148-
config,
149-
::Const{typeof(ignore_derivatives)},
150-
dret::Active, tape, x::Active
151-
)
152-
return (Enzyme.make_zero(x.val),)
153-
end
154-
155-
function EnzymeRules.reverse(
156-
config,
157-
::Const{typeof(ignore_derivatives)},
158-
::Type{<:Duplicated}, tape, x::Duplicated
159-
)
160-
return (nothing,)
161-
end
162-
end
163-
164122
# ╔═╡ 873e7792-99a1-4472-92c2-6fc32e2889fa
165123
N_stop(xᵢ, θ) = θ[1] * ignore_derivatives(xᵢ^2) + θ[2] * ignore_derivatives(xᵢ)
166124

@@ -170,6 +128,7 @@ plot_gradientfield(N_stop, S, x₀, y, θ₁, θ₂, n)
170128
# ╔═╡ Cell order:
171129
# ╠═b72e9218-81ba-11f0-1eba-5bd949c7ade4
172130
# ╠═9f5c0822-a19a-4c63-95e7-d2f066a7440f
131+
# ╠═23a8503f-3c68-4523-aebe-a4ce4575a02b
173132
# ╠═a4453d23-6e31-451f-b2cd-97346accac82
174133
# ╠═bd0352c3-1b3c-42f5-ab93-7ca4cb67b9ad
175134
# ╟─df72e42f-7eec-476f-8ce5-72b09f620005
@@ -182,6 +141,5 @@ plot_gradientfield(N_stop, S, x₀, y, θ₁, θ₂, n)
182141
# ╠═45ee18f4-d6d3-40f4-bbc0-04cbd3b7b840
183142
# ╠═ae6a671d-1559-4bff-af6e-78d2b54db020
184143
# ╠═be852753-126d-42fa-a55c-c907f5dce99d
185-
# ╠═0b6d0456-1f94-479f-b690-89ad7fc61e44
186144
# ╠═873e7792-99a1-4472-92c2-6fc32e2889fa
187145
# ╠═d71a22cc-c1f3-4425-8a6f-442a0bc4f215

lib/EnzymeCore/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EnzymeCore"
22
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
4-
version = "0.8.13"
4+
version = "0.8.14"
55

66
[compat]
77
Adapt = "3, 4"
@@ -15,6 +15,7 @@ AdaptExt = "Adapt"
1515

1616
[extras]
1717
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
18+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1819

1920
[weakdeps]
2021
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

lib/EnzymeCore/src/EnzymeCore.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplic
66
export MixedDuplicated, BatchMixedDuplicated
77
export DefaultABI, FFIABI, InlineABI, NonGenABI
88
export BatchDuplicatedFunc
9-
export within_autodiff
9+
export within_autodiff, ignore_derivatives
1010
export needs_primal
1111

1212
function batch_size end
@@ -620,6 +620,27 @@ Returns true if within autodiff, otherwise false.
620620
return false
621621
end
622622

623+
"""
624+
ignore_derivatives(x::T)::T
625+
626+
Behaves like the `identity` function, but disconnects the "shadow"
627+
associated with `x`. This has the effect of preventing any derivatives
628+
from being propagated through `x`.
629+
630+
!!! compat "Enzyme 0.13.74"
631+
Support for `ignore_derivatives` was added in Enzyme 0.13.74.
632+
"""
633+
@generated function ignore_derivatives(x::T) where {T}
634+
name = "extern __enzyme_ignore_derivatives." * string(T)
635+
return quote
636+
if EnzymeCore.within_autodiff()
637+
return ccall($name, llvmcall, $T, ($T,), x)
638+
else
639+
return x
640+
end
641+
end
642+
end
643+
623644
"""
624645
set_err_if_func_written(::Mode)
625646

lib/EnzymeCore/test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,7 @@ end
4141
@testset "within_autodiff" begin
4242
@test !EnzymeCore.within_autodiff()
4343
end
44+
45+
@testset "ignore_derivatives" begin
46+
@test EnzymeCore.ignore_derivatives(3) == 3
47+
end

src/Enzyme.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ import EnzymeCore:
5353
set_strong_zero,
5454
clear_strong_zero,
5555
within_autodiff,
56+
ignore_derivatives,
5657
WithPrimal,
5758
NoPrimal,
5859
needs_primal,

src/compiler.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2604,6 +2604,29 @@ function enzyme!(
26042604
if DumpPostWrap[]
26052605
API.EnzymeDumpModuleRef(mod.ref)
26062606
end
2607+
2608+
# Rewrite enzyme_ignore_derivatives functions to the identity of their first argument.
2609+
to_delete = LLVM.Function[]
2610+
for fn in functions(mod)
2611+
if startswith(name(fn), "__enzyme_ignore_derivatives")
2612+
push!(to_delete, fn)
2613+
to_delete_inst = LLVM.CallInst[]
2614+
for u in LLVM.uses(fn)
2615+
ci = LLVM.user(u)
2616+
@assert isa(ci, LLVM.CallInst)
2617+
LLVM.replace_uses!(ci, operands(ci)[1])
2618+
push!(to_delete_inst, ci)
2619+
end
2620+
for ci in to_delete_inst
2621+
LLVM.erase!(ci)
2622+
end
2623+
end
2624+
end
2625+
for fn in to_delete
2626+
LLVM.erase!(fn)
2627+
end
2628+
LLVM.verify(mod)
2629+
26072630
API.EnzymeLogicErasePreprocessedFunctions(logic)
26082631
adjointfname = adjointf == nothing ? nothing : LLVM.name(adjointf)
26092632
augmented_primalfname =

test/ignore_derivatives.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using Test
2+
using Enzyme
3+
import Enzyme: ignore_derivatives
4+
5+
@testset "ignore_derivatives" begin
6+
@test autodiff(Enzyme.Forward, ignore_derivatives, Duplicated(1.0, 2.0)) == (0.0,)
7+
@test autodiff(Enzyme.Reverse, ignore_derivatives, Active(1.0)) == ((0.0,),)
8+
end
9+
10+
N(xᵢ, θ) = θ[1] * xᵢ^2 + θ[2] * xᵢ
11+
N_stop(xᵢ, θ) = θ[1] * ignore_derivatives(xᵢ^2) + θ[2] * ignore_derivatives(xᵢ)
12+
13+
@testset "simulate with ignore_derivatives" begin
14+
x₀ = -0.3
15+
θ = (-4.0, 4.0)
16+
17+
= MixedDuplicated(θ, Ref(Enzyme.make_zero(θ)))
18+
@test Enzyme.autodiff(Enzyme.Reverse, N, Active(x₀), dθ) == ((6.4, nothing),)
19+
@test.dval[] == (0.09, -0.3)
20+
21+
22+
= MixedDuplicated(θ, Ref(Enzyme.make_zero(θ)))
23+
@test Enzyme.autodiff(Enzyme.Reverse, N_stop, Active(x₀), dθ) == ((0.0, nothing),)
24+
@test.dval[] == (0.09, -0.3)
25+
end

0 commit comments

Comments
 (0)