Hello :)
This is a simplified version of a part of a model that I'm building. It's a parallel layer of identical submodels. Each takes an image and returns a number.
using Reactant, Lux, MLDataDevices, Random, Optimisers, Enzyme
Reactant.set_default_backend("gpu")
function make_branch()
    ls = []
    push!(ls, Lux.Conv((3, 3), 3 => 18, use_bias = false))
    push!(ls, Lux.Conv((3, 3), 18 => 36, use_bias = false))
    push!(ls, Lux.Conv((3, 3), 36 => 72, use_bias = false))
    push!(ls, Lux.AdaptiveMeanPool((2, 2)))
    push!(ls, Lux.MaxPool((2, 2)))
    push!(ls, Lux.FlattenLayer())
    push!(ls, Lux.Dense(72 => 32))
    push!(ls, Lux.Dense(32 => 18))
    push!(ls, Lux.Dense(18 => 1))
    return Lux.Chain(ls...)
end
parallel_layer = Lux.Parallel(
    (xs...) -> cat(xs..., dims = 1),
    make_branch(),
    make_branch(),
    make_branch(),
    make_branch(),
    make_branch(),
    make_branch(),
) # output should be (6, B)
Outputs from the models are concatenated.
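As a quick CPU-only sanity check (no Reactant; assumes `parallel_layer` from the snippet above is in scope, batch size 4 chosen arbitrarily), the concatenated output shape can be verified like this:

```julia
using Lux, Random

# Each branch maps a 32x32x3 image to one number:
# three 3x3 convs -> 26x26x72 -> AdaptiveMeanPool((2,2)) -> 2x2x72
# -> MaxPool((2,2)) -> 1x1x72 -> Flatten -> 72 -> Dense chain -> 1
ps, st = Lux.setup(Random.default_rng(), parallel_layer)
x = rand(Float32, 32, 32, 3, 4)   # batch of 4 images
y, _ = parallel_layer(x, ps, st)  # non-tuple input is fed to every branch
@assert size(y) == (6, 4)         # one scalar per branch, per sample
```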
For this simple model (~200k parameters), compilation takes more than 10 minutes.
dev = reactant_device(force = true)
ps, st = Lux.setup(Random.default_rng(), parallel_layer)
struct CustomLoss <: Lux.AbstractLossFunction end

function (::CustomLoss)(
        model::Lux.AbstractLuxLayer, ps, st::NamedTuple, data
    )
    x, y = data
    pred, new_st = model(x, ps, st)
    res = sum(abs2.(pred .- y))
    return res, new_st, NamedTuple()
end
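For reference, the built-in loss could probably be swapped in to test whether the custom loss matters (untested here; note that `MSELoss` averages where the custom loss above sums, so gradients differ only by a constant factor):

```julia
# Hypothetical variant using Lux's built-in loss instead of CustomLoss.
# MSELoss() computes mean(abs2, pred .- y), i.e. the sum-of-squares
# loss above divided by the number of elements.
loss_fn = MSELoss()
```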
loss_fn = CustomLoss() # Maybe there is the same issue with the default MSELoss. Haven't tried it.
optimizer = Adam()
x, y = rand(Float32, 32, 32, 3, 256), rand(Float32, 6, 256)
# Reactant
x, y = x |> dev, y |> dev
ps, st = ps |> dev, st |> dev
@time train_state = Lux.Training.TrainState(parallel_layer, ps, st, optimizer);
@time grads, loss, pack, train_state = Lux.Training.single_train_step(AutoEnzyme(; mode = Reverse), loss_fn, (x, y), train_state);
Constructing the TrainState took a reasonable amount of time:
julia> @time train_state = Lux.Training.TrainState(parallel_layer, ps, st, optimizer);
41.999941 seconds (29.63 M allocations: 1.543 GiB, 0.80% gc time, 98.90% compilation time: 1% of which was recompilation)
But compiling the update step was very slow (~11 minutes):
julia> @time grads, loss, pack, train_state = Lux.Training.single_train_step(AutoEnzyme(; mode = Reverse), loss_fn, (x, y), train_state);
700.861117 seconds (846.38 M allocations: 43.138 GiB, 0.97% gc time, 98.30% compilation time: <1% of which was recompilation)
With a larger example I hit more than 20 minutes, and with the model I actually want to run, I've never seen compilation finish.