You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hi,
I'm wondering how other people are handling working with reactant precompiled functions.
In a training loop, I'm evaluating training and validation loss. Using reactant, I understand that function calls where I pass reactant arrays need to be precompiled before calling. In my training loop I have to precompile onehotbatch. In the function to estimate losses, I precompile onehotbatch, and the model itself.
Another approach is to calculate the losses on the cpu, like in the HyperNet tutorial.
"""
    estimate_loss(model, ps, st, data_loader; max_iter=1000)

Estimate the mean cross-entropy loss of `model` (with parameters `ps` and
states `st`) over up to `max_iter` batches drawn from `data_loader`.

Reactant requires every function that receives reactant arrays to be compiled
before it is called, so `onehotbatch` and the model forward pass are compiled
once against a representative batch and the compiled versions are reused
inside the loop.

Returns the average loss over the batches actually processed.
"""
function estimate_loss(model, ps, st, data_loader; max_iter=1000)
    println("Estimating loss")
    loss_total = 0.0
    n_batches = 0  # count batches actually seen, so the mean is correct
    loss_fn = CrossEntropyLoss(; agg=mean, logits=Val(true))
    # Compile once with a representative batch; reuse the compiled callables.
    # NOTE(review): the label range 1:65 is hard-coded — presumably the
    # vocabulary size; confirm it matches the dataset.
    (xb, yb) = first(data_loader)
    onehotbatch_c = @compile onehotbatch(yb, 1:65)
    model_c = @compile model(xb, ps, st)
    for (ix_b, batch) in ProgressBar(enumerate(data_loader))
        ix_b > max_iter && break
        xb, yb = batch
        yb_hot = onehotbatch_c(yb, 1:65)
        y_pred, _ = model_c(xb, ps, st)
        loss_total += loss_fn(y_pred, yb_hot)
        n_batches += 1
    end
    # Bug fix: divide by the number of batches actually processed — the loader
    # may hold fewer than `max_iter` batches, in which case dividing by
    # `max_iter` would silently deflate the estimate.
    return n_batches == 0 ? loss_total : loss_total / n_batches
end
"""
    train(tstate::Training.TrainState, vjp, loader_train, loader_valid, num_iter)

Run up to `num_iter` training steps, logging training/validation loss every
100 iterations, and return the final `TrainState`.

`onehotbatch` is compiled once for reactant arrays using a representative
batch, then reused for every step (Reactant requires functions called with
reactant arrays to be compiled first).
"""
function train(tstate::Training.TrainState, vjp, loader_train, loader_valid, num_iter)
    println("Training")
    loss_iter = 0.0
    n_steps = 0  # steps actually executed (loader may be shorter than num_iter)
    # Hoisted out of the loop: the loss object is immutable configuration and
    # does not need to be rebuilt every iteration.
    loss_fn = CrossEntropyLoss(; agg=mean, logits=Val(true))
    time_start = time_ns()
    # NOTE(review): label range 1:65 is hard-coded — confirm it matches the
    # dataset's vocabulary size.
    (_, yb) = first(loader_train)
    onehotbatch_c = @compile onehotbatch(yb, 1:65)
    for (ix_it, batch) in enumerate(loader_train)
        ix_it > num_iter && break
        xb, yb = batch
        yb_hot = onehotbatch_c(yb, 1:65)
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_fn, (xb, yb_hot), tstate)
        loss_iter += loss
        n_steps = ix_it
        if mod(ix_it, 100) == 0
            time_stop = time_ns()
            ps_c = tstate.parameters
            st_c = Lux.testmode(tstate.states)
            loss_train = estimate_loss(tstate.model, ps_c, st_c, loader_train)
            loss_valid = estimate_loss(tstate.model, ps_c, st_c, loader_valid)
            time_elapsed_s = (time_stop - time_start) / 1e9
            # Bug fixes: (1) `loss_train` was computed (expensively) but never
            # reported; (2) the running loss was divided by `num_iter` instead
            # of the steps completed so far, which under-reports early in
            # training.
            println("Iteration $(ix_it). Training loss: $(loss_train). Running loss: $(loss_iter / ix_it). Validation loss: $(loss_valid). Elapsed time: $(round(time_elapsed_s, digits=4))s")
            time_start = time_ns()
        end
    end
    # Bug fix: average over the steps actually run, not unconditionally over
    # `num_iter` (the loader may be exhausted early).
    loss_total = n_steps == 0 ? loss_iter : loss_iter / n_steps
    println("----- Final loss: $(loss_total)")
    return tstate
end
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
Uh oh!
There was an error while loading. Please reload this page.
-
Hi,
I'm wondering how other people are handling working with reactant precompiled functions.
In a training loop, I'm evaluating training and validation loss. Using reactant, I understand that function calls where I pass reactant arrays need to be precompiled before calling. In my training loop I have to precompile `onehotbatch`.
In the function to estimate losses, I precompile `onehotbatch` and the model itself. Another approach is to calculate the losses on the CPU, like in the HyperNet tutorial.
Beta Was this translation helpful? Give feedback.
All reactions