Closed
Description
Hi, it seems that calling the `Restructure` returned by `destructure` is not type-stable by default. This is currently causing issues with Enzyme.jl (see this issue). Here is an MWE that illustrates the point:
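```julia
using Cthulhu, LinearAlgebra, Optimisers, Functors
using InteractiveUtils  # provides @code_warntype outside the REPL

struct Model{A,B}
    a::A
    b::B
end
Functors.@functor Model

# One trainable vector field and one structured-matrix field
m = Model(randn(10), LowerTriangular(Matrix(I, 10, 10)))
params, re = Optimisers.destructure(m)
@code_warntype re(params)
```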
This prints:

```
MethodInstance for (::Optimisers.Restructure{Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}, @NamedTuple{a::Int64, b::Tuple{}}})(::Vector{Float64})
  from (re::Optimisers.Restructure)(flat::AbstractVector) @ Optimisers ~/.julia/packages/Optimisers/yDIWk/src/destructure.jl:59
Arguments
  re::Optimisers.Restructure{Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}, @NamedTuple{a::Int64, b::Tuple{}}}
  flat::Vector{Float64}
Body::Model
1 ─ %1 = Base.getproperty(re, :model)::Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}
│   %2 = Base.getproperty(re, :offsets)::@NamedTuple{a::Int64, b::Tuple{}}
│   %3 = Base.getproperty(re, :length)::Int64
│   %4 = Optimisers._rebuild(%1, %2, flat, %3)::Model
└──      return %4
```
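The inferred return type is the abstract `Model` rather than the concrete `Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}`, i.e. the compiler cannot infer the type parameters of the result.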
This can be solved in a brute-force manner by annotating the return type of the call method of `re::Restructure` (defined here):

```julia
(re::Restructure)(flat::AbstractVector)::typeof(re.model) = _rebuild(re.model, re.offsets, flat, re.length)
```
This tells the compiler that the result has the same type as `re.model`. I think this is a safe assumption, and it immediately resolves the instability. Any thoughts on this?
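To illustrate why the annotation helps without depending on Optimisers internals, here is a minimal self-contained sketch. `ToyRestructure`, `opaque_rebuild` and `rebuild_plain` are hypothetical stand-ins, not Optimisers API: `opaque_rebuild` routes its result through a `Ref{Any}` so that inference can only see `Any`, loosely mimicking the abstract `Model` inferred for `_rebuild` above.

```julia
# Hypothetical stand-in for Optimisers.Restructure: it only stores a template model.
struct ToyRestructure{T}
    model::T
end

# Hypothetical rebuild helper: builds a value of the template's concrete type,
# then passes it through a Ref{Any} so inference can only see `Any`.
function opaque_rebuild(model::AbstractVector, flat::AbstractVector)
    rebuilt = similar(model)    # same concrete type as the template
    copyto!(rebuilt, flat)      # fill it from the flat parameter vector
    return Ref{Any}(rebuilt)[]  # reading from a Ref{Any} hides the type
end

# Unannotated version: the inferred return type is `Any`.
rebuild_plain(re::ToyRestructure, flat::AbstractVector) = opaque_rebuild(re.model, flat)

# Annotated version, mirroring the proposed fix: Julia inserts a convert to
# typeof(re.model) plus a type assertion, so inference can use the concrete
# type of `re.model` as the return type.
(re::ToyRestructure)(flat::AbstractVector)::typeof(re.model) = opaque_rebuild(re.model, flat)

re = ToyRestructure(zeros(3))
println(Base.return_types(rebuild_plain, (typeof(re), Vector{Float64})))  # inferred: Any
println(Base.return_types(re, (Vector{Float64},)))                       # inferred: Vector{Float64}
println(re([1.0, 2.0, 3.0]))
```

One caveat with this approach: the return annotation is not a pure hint. Julia converts the result to the declared type and asserts it, so if `_rebuild` could ever produce a model whose type differs from `re.model`, the call would convert it or throw instead of returning it unchanged.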