
Restructure is not type stable but could be made stable? #177

@Red-Portal


Hi, it seems that calling a `Restructure` (the `re` returned by `destructure`) is not type stable by default. This is currently causing issues with Enzyme.jl (see this issue). Here is an MWE to illustrate the point:

```julia
using Cthulhu, LinearAlgebra, Optimisers, Functors

struct Model{A,B}
    a::A
    b::B
end

Functors.@functor Model

m = Model(randn(10), LowerTriangular(Matrix(I, 10, 10)))

params, re = Optimisers.destructure(m)

@code_warntype re(params)
```

This returns:

```
MethodInstance for (::Optimisers.Restructure{Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}, @NamedTuple{a::Int64, b::Tuple{}}})(::Vector{Float64})
  from (re::Optimisers.Restructure)(flat::AbstractVector) @ Optimisers ~/.julia/packages/Optimisers/yDIWk/src/destructure.jl:59
Arguments
  re::Optimisers.Restructure{Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}, @NamedTuple{a::Int64, b::Tuple{}}}
  flat::Vector{Float64}
Body::Model
1 ─ %1 = Base.getproperty(re, :model)::Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}
│   %2 = Base.getproperty(re, :offsets)::@NamedTuple{a::Int64, b::Tuple{}}
│   %3 = Base.getproperty(re, :length)::Int64
│   %4 = Optimisers._rebuild(%1, %2, flat, %3)::Model
└──      return %4
```

Here the inferred return type is the abstract `Model`, with its type parameters unknown, even though `re.model` has a fully concrete type. So the call is not type stable.

This can be solved in a brute-force manner by defining the call method of `Restructure` (defined here) with a return-type annotation:

```julia
(re::Restructure)(flat::AbstractVector)::typeof(re.model) = _rebuild(re.model, re.offsets, flat, re.length)
```

This tells the compiler that the return value has the same type as `re.model`. I think this is safe to assume, and it immediately resolves the instability. Any thoughts on this?
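To see the mechanism in isolation, here is a minimal, self-contained sketch that does not depend on Optimisers. `Model` mirrors the MWE, but `ReUnstable`, `ReStable`, and `rebuild_unstable` are hypothetical stand-ins for `Restructure` and `Optimisers._rebuild`; the inner function deliberately hides types by routing values through a `Vector{Any}`. `Base.return_types` is used to compare what inference reports with and without the annotation.

```julia
# Same shape as the MWE's Model.
struct Model{A,B}
    a::A
    b::B
end

# Hypothetical stand-in for Optimisers._rebuild. Routing the pieces through a
# Vector{Any} hides their types, so inference can only conclude "some Model".
function rebuild_unstable(model::Model, flat::AbstractVector)
    parts = Any[flat[1:length(model.a)], model.b]
    return Model(parts[1], parts[2])
end

# Hypothetical stand-in for Restructure without a return-type annotation.
struct ReUnstable{M}
    model::M
end
(re::ReUnstable)(flat::AbstractVector) = rebuild_unstable(re.model, flat)

# Same call with the annotation proposed above. re.model has the concrete
# type M, so inference knows typeof(re.model) and the result is converted
# and asserted to that type.
struct ReStable{M}
    model::M
end
(re::ReStable)(flat::AbstractVector)::typeof(re.model) = rebuild_unstable(re.model, flat)

m = Model([0.1, 0.2, 0.3], "fixed")
flat = [1.0, 2.0, 3.0]

# Inferred return types for a Vector{Float64} argument.
rt_unstable = only(Base.return_types(ReUnstable(m), (Vector{Float64},)))
rt_stable = only(Base.return_types(ReStable(m), (Vector{Float64},)))

println(rt_unstable)  # abstract: type parameters are unknown
println(rt_stable)    # concrete: Model{Vector{Float64}, String}
```

Note that a return-type annotation lowers to a `convert` to the declared type followed by a type assertion. If the inner function ever produced a model of a different type, the call would attempt that conversion and throw a `MethodError` when none exists, rather than silently returning the other type.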
