In Flux/Zygote I can use any loss function, whereas SimpleChains seems to allow only absolute loss, squared loss, and cross-entropy loss (or am I wrong?). What is the reason that an arbitrary loss cannot be used? I would like something like this:
```julia
using SimpleChains

y = rand(Float32, 3, 2^8)

s = SimpleChain(
    static(3),
    TurboDense{true}(softsign, 100),
    TurboDense{true}(softsign, 3)
);

# Squared error normalized by the squared norm of the targets
my_loss_fun(arg) = sum(i -> (arg[i] - y[i])^2, eachindex(y)) / sum(abs2, y)

my_loss = SimpleChains.Loss(my_loss_fun, y)        # proposed syntax
train_loss = SimpleChains.add_loss(s, my_loss)     # proposed syntax
```
Is it possible to do something like this? What would it take to use my own loss with SimpleChains?
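To make the request concrete, here is a minimal sketch in plain Julia of the loss value and its analytic gradient with respect to the network output, which is presumably what a custom loss would have to supply. The names `rel_sq_loss`, `rel_sq_loss_grad`, and `yhat` are hypothetical, and nothing here uses or implies any SimpleChains API.

```julia
# Plain-Julia sketch; no SimpleChains calls are used.
y = rand(Float32, 3, 2^8)

# Relative squared error: sum((yhat - y).^2) / sum(y.^2)
rel_sq_loss(yhat, y) = sum(abs2, yhat .- y) / sum(abs2, y)

# Analytic gradient with respect to yhat: 2 .* (yhat - y) ./ sum(y.^2)
rel_sq_loss_grad(yhat, y) = (2 / sum(abs2, y)) .* (yhat .- y)

yhat = rand(Float32, 3, 2^8)   # stand-in for the network output
l = rel_sq_loss(yhat, y)
g = rel_sq_loss_grad(yhat, y)
```

The gradient has the same shape as `yhat`, so in principle it could be passed back through the layers the same way the built-in squared loss is.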