-
Notifications
You must be signed in to change notification settings - Fork 90
Open
Description
using Enzyme
using LinearAlgebra
using SparseArrays
using Test
# Switch on Enzyme's `printall!` and `printactivity!` API flags for this run.
# NOTE(review): presumably these enable verbose compiler / activity-analysis
# output for the bug report — confirm against the Enzyme.API documentation.
Enzyme.API.printall!(true)
Enzyme.API.printactivity!(true)
# Wrapper similar to EnzymeTestUtils
# Generic (varargs) wrapper: invoke `f` on `args...` with call-site inlining
# requested, then discard whatever `f` returned. The wrapper itself always
# returns `nothing`.
@inline function call_wrapped(f::FT, args...) where {FT}
    # Call-site `@inline` asks the compiler to inline this particular call of `f`.
    @inline f(args...)
    return
end
# Generate explicit fixed-arity methods
#     call_wrapped(f, arg1), ..., call_wrapped(f, arg1, ..., arg5)
# Each generated method forwards its positional arguments to `f` and, having no
# explicit `return nothing`, yields the value of that call.
for arity in 1:5
    # Positional parameter names :arg1, :arg2, ..., one per argument.
    params = map(i -> Symbol(:arg, i), 1:arity)
    @eval function call_wrapped(f::FT, $(params...)) where {FT}
        Base.@_inline_meta
        @inline f($(params...))
    end
end
# Enzyme.Compiler.set_fn_max_args(call_wrapped)
# Reproducer: reverse-mode `Enzyme.autodiff` of the five-argument `mul!`
# (sparse complex matrix times dense complex vector), routed through
# `call_wrapped`. Prints a success line, or the caught error, to stdout.
function run_autodiff_test()
    Ts = ComplexF64

    # Dense 3×2 real template; only the second column has nonzero entries.
    M0 = [
        0.0 1.50614;
        0.0 -0.988357;
        0.0 0.0
    ]
    # Sparse complex operand with real part M0 and imaginary part 2 * M0.
    M = SparseMatrixCSC((M0 .+ 2im * M0))

    # Output buffer with one entry per row of M.
    C = zeros(Ts, size(M, 1))

    # Random inputs, drawn in the order v, α, β.
    v = rand(Ts, 2)
    α = rand(Ts)
    β = rand(Ts)

    # Activity annotations: C and v carry zero-initialised shadows, M is
    # constant, and the scalars are wrapped as Active.
    C_dup = Duplicated(copy(C), zero(C))
    M_const = Const(M)
    v_dup = Duplicated(copy(v), zero(v))
    # `α_act` is constructed but not passed below; the call uses `Const(1.0)`
    # in the α position instead.
    α_act = Active(α)
    β_act = Active(β)

    println("Running autodiff with wrapper...")
    try
        # autodiff signature: (mode, f, return_activity_type, args...).
        # The return activity passed here is `Const`; `mul!` itself is handed
        # to the wrapper as a `Const` argument.
        Enzyme.autodiff(
            Reverse, call_wrapped, Const,
            Const(mul!), C_dup, M_const, v_dup, Const(1.0), β_act,
        )
        println("Success with wrapped autodiff!")
    catch err
        # Report the failure on stdout rather than aborting the script.
        showerror(stdout, err)
        println()
    end
end
# Run the reproducer as soon as the script is loaded.
run_autodiff_test()
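On Julia 1.12, using EnzymeAD/Enzyme#2684:
this passes with the macro-expanded wrapper (the fixed-arity `call_wrapped` methods generated by the `eval` loop above), but fails when `Enzyme.Compiler.set_fn_max_args(call_wrapped)` is used instead.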
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels