Closed
Description
I have a package that defines a simple utility array type: https://github.com/maartenvd/MPSKit.jl/blob/diskarray/src/utility/periodicarray.jl
Zygote changes the type of a broadcast result when taking the derivative, which later makes my backward rules fail. Here is a minimal example:
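```julia
julia> using Zygote, LinearAlgebra  # PeriodicArray comes from the MPSKit diskarray branch linked above

julia> f_add(x) = x + 3;

julia> function myfun(x)
           y = f_add.(x)
           @show typeof(y)
           norm(y)
       end
myfun (generic function with 1 method)

julia> myfun'(PeriodicArray(rand(5,5)));  # Zygote's adjoint shorthand: gradient of myfun
typeof(y) = Matrix{Float64}

julia> myfun(PeriodicArray(rand(5,5)));
typeof(y) = PeriodicArray{Float64, 2}
```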
This type change causes failures: Zygote then calls `rrule` with a tangent of type `PeriodicArray` but a (wrong) primal of type `Matrix`.
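In the plain (non-differentiated) call, the result keeps its wrapper type because the array type customizes broadcasting. How PeriodicArray does this isn't shown in this issue, so the following is only a minimal sketch, assuming the standard `BroadcastStyle`/`similar` pattern from the Julia manual's broadcasting interface. `WrappedArray` is a hypothetical stand-in, not the MPSKit code.

```julia
using Base.Broadcast: Broadcasted, ArrayStyle

# Hypothetical stand-in for a wrapper array type such as PeriodicArray.
# This is NOT the MPSKit implementation, only a sketch of the usual pattern.
struct WrappedArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end

# Minimal AbstractArray interface, forwarding to the wrapped Array.
Base.size(A::WrappedArray) = size(A.data)
Base.IndexStyle(::Type{<:WrappedArray}) = IndexLinear()
Base.getindex(A::WrappedArray, i::Int) = A.data[i]
Base.setindex!(A::WrappedArray, v, i::Int) = (A.data[i] = v)

# Declare a custom broadcast style so broadcasting allocates via the `similar` method below.
Base.BroadcastStyle(::Type{<:WrappedArray}) = ArrayStyle{WrappedArray}()

# Allocate the broadcast output as a WrappedArray instead of a plain Array.
function Base.similar(bc::Broadcasted{ArrayStyle{WrappedArray}}, ::Type{ElType}) where {ElType}
    return WrappedArray(similar(Array{ElType}, axes(bc)))
end

f_add(x) = x + 3

x = WrappedArray(rand(5, 5))
y = f_add.(x)
@show typeof(y)  # a WrappedArray, not a Matrix
```

This reproduces only the primal path. The `Matrix` result seen under `myfun'` comes from Zygote's broadcasting adjoint and is not modeled by this sketch.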