
Very long compile time for a simple parallel layer with Reactant, Enzyme and a GPU backend #1484

@camilodlt

Description

Hello :)

This is a simplified version of a part of a model that I'm building. It's a parallel layer of identical submodels. Each takes an image and returns a number.

using Reactant, Lux, MLDataDevices, Random, Optimisers, Enzyme
Reactant.set_default_backend("gpu")

function make_branch()
    return Lux.Chain(
        Lux.Conv((3, 3), 3 => 18; use_bias = false),
        Lux.Conv((3, 3), 18 => 36; use_bias = false),
        Lux.Conv((3, 3), 36 => 72; use_bias = false),
        Lux.AdaptiveMeanPool((2, 2)),
        Lux.MaxPool((2, 2)),
        Lux.FlattenLayer(),
        Lux.Dense(72 => 32),
        Lux.Dense(32 => 18),
        Lux.Dense(18 => 1),
    )
end

parallel_layer = Lux.Parallel(
    (xs...) -> cat(xs..., dims = 1),
    make_branch(),
    make_branch(),
    make_branch(),
    make_branch(),
    make_branch(),
    make_branch(),
) # should be (6, B)

Outputs from the models are concatenated.
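To illustrate the shape claim in the comment above: each branch ends in `Dense(18 => 1)`, so it emits a `(1, B)` matrix, and `cat(...; dims = 1)` stacks the six of them into `(6, B)`. A minimal sketch in plain Julia (no Lux needed; the random matrices stand in for branch outputs):

```julia
# Each of the 6 branches produces a (1, B) matrix; cat along dims = 1
# stacks them row-wise into a single (6, B) matrix.
B = 256
branch_outputs = [rand(Float32, 1, B) for _ in 1:6]
stacked = cat(branch_outputs...; dims = 1)
size(stacked)  # (6, 256)
```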

For this simple model (~200k parameters), compilation takes more than 10 minutes.
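As a sanity check on the parameter count, a back-of-envelope calculation per branch (the convolutions are bias-free because of `use_bias = false`; the `Dense` layers include a bias by default):

```julia
# Per-branch parameter count for the architecture above.
conv_params  = 3*3*3*18 + 3*3*18*36 + 3*3*36*72           # bias-free convs: 29_646
dense_params = (72*32 + 32) + (32*18 + 18) + (18*1 + 1)   # Dense w/ bias:   2_949
branch = conv_params + dense_params                       # 32_595 per branch
total  = 6 * branch                                       # 195_570 ≈ 200k
```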

dev = reactant_device(force = true)
ps, st = Lux.setup(Random.default_rng(), parallel_layer)

struct CustomLoss <: Lux.AbstractLossFunction
end
function (::CustomLoss)(
        model::Lux.AbstractLuxLayer, ps, st::NamedTuple, data
    )
    x, y = data
    pred, new_st = model(x, ps, st)
    res = sum(abs2.(pred - y))
    return res, new_st, NamedTuple()
end
loss_fn = CustomLoss() # The same issue may occur with the default MSELoss; I haven't tried it.
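For clarity on what `CustomLoss` computes: it is the sum of squared errors over the whole batch (not the mean), so it differs from a mean-based loss only by a constant factor. A tiny plain-Julia sketch of the same expression:

```julia
# CustomLoss reduces to sum-of-squared-errors over all entries.
pred = Float32[1 2; 3 4]
y    = Float32[1 1; 3 3]
res  = sum(abs2.(pred - y))  # (2-1)^2 + (4-3)^2 = 2.0f0
```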

optimizer = Adam()
x, y = rand(Float32, 32, 32, 3, 256), rand(Float32, 6, 256)

# Reactant
x, y = x |> dev, y |> dev
ps, st = ps |> dev, st |> dev
@time train_state = Lux.Training.TrainState(parallel_layer, ps, st, optimizer);
@time grads, loss, pack, train_state = Lux.Training.single_train_step(AutoEnzyme(; mode = Reverse), loss_fn, (x, y), train_state);

Creating the TrainState took an acceptable amount of time:

julia> @time train_state = Lux.Training.TrainState(parallel_layer, ps, st, optimizer);
 41.999941 seconds (29.63 M allocations: 1.543 GiB, 0.80% gc time, 98.90% compilation time: 1% of which was recompilation)

But compiling the update step was very slow (about 11 minutes):

julia> @time grads, loss, pack, train_state = Lux.Training.single_train_step(AutoEnzyme(; mode = Reverse), loss_fn, (x, y), train_state);
700.861117 seconds (846.38 M allocations: 43.138 GiB, 0.97% gc time, 98.30% compilation time: <1% of which was recompilation)

With a larger example I hit more than 20 minutes, and with the model I actually want to run, compilation has never finished.
