Skip to content

Commit 0e651e4

Browse files
authored
Enzyme: simplify via mixedduplicated (#483)
1 parent b1d557b commit 0e651e4

File tree

2 files changed

+3
-33
lines changed

2 files changed

+3
-33
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
2121
[compat]
2222
Adapt = "0.4, 1.0, 2.0, 3.0, 4"
2323
Atomix = "0.1"
24-
EnzymeCore = "0.7.1"
24+
EnzymeCore = "0.7.5"
2525
InteractiveUtils = "1.6"
2626
LinearAlgebra = "1.6"
2727
MacroTools = "0.5"

ext/EnzymeExt.jl

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -70,25 +70,6 @@ module EnzymeExt
7070
fwd_kernel(f, args...; ndrange, workgroupsize)
7171
end
7272

73-
74-
@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
75-
if !any(ActiveTys)
76-
return f
77-
end
78-
function inact(ctx, args2::Vararg{Any, N}) where N
79-
args3 = ntuple(Val(N)) do i
80-
Base.@_inline_meta
81-
if ActiveTys[i]
82-
args2[i][]
83-
else
84-
args2[i]
85-
end
86-
end
87-
f(ctx, args3...)
88-
end
89-
return inact
90-
end
91-
9273
function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
9374
kernel = func.val
9475
f = kernel.f
@@ -102,11 +83,6 @@ module EnzymeExt
10283
# TODO autodiff_deferred on the func.val
10384
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
10485

105-
tup = Val(ntuple(Val(N)) do i
106-
Base.@_inline_meta
107-
args[i] isa Active
108-
end)
109-
f = make_active_byref(f, tup)
11086
FT = Const{Core.Typeof(f)}
11187

11288
arg_refs = ntuple(Val(N)) do i
@@ -120,7 +96,7 @@ module EnzymeExt
12096
args2 = ntuple(Val(N)) do i
12197
Base.@_inline_meta
12298
if args[i] isa Active
123-
Duplicated(Ref(args[i].val), arg_refs[i])
99+
EnzymeCore.MixedDuplicated(args[i].val, arg_refs[i])
124100
else
125101
args[i]
126102
end
@@ -150,7 +126,7 @@ module EnzymeExt
150126
args2 = ntuple(Val(N)) do i
151127
Base.@_inline_meta
152128
if args[i] isa Active
153-
Duplicated(Ref(args[i].val), arg_refs[i])
129+
EnzymeCore.MixedDuplicated(args[i].val, arg_refs[i])
154130
else
155131
args[i]
156132
end
@@ -159,12 +135,6 @@ module EnzymeExt
159135
kernel = func.val
160136
f = kernel.f
161137

162-
tup = Val(ntuple(Val(N)) do i
163-
Base.@_inline_meta
164-
args[i] isa Active
165-
end)
166-
f = make_active_byref(f, tup)
167-
168138
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
169139

170140
rev_kernel = similar(func.val, rev_cpu)

0 commit comments

Comments
 (0)