Match Flux's support of arbitrary loss functions #143

@Vilin97

In Flux/Zygote I can use any loss function, whereas SimpleChains only allows absolute loss, squared loss and cross-entropy loss (or am I wrong?). Is there a reason an arbitrary loss cannot be used? I would like to write something like this:

```julia
using SimpleChains
y = rand(Float32, 3, 2^8)
s = SimpleChain(
  static(3),
  TurboDense{true}(softsign, 100),
  TurboDense{true}(softsign, 3)
);
# relative squared error of the model output `arg` against the targets `y`
my_loss_fun(arg) = sum(i -> (arg[i] - y[i])^2, eachindex(y))/sum(abs2, y)
my_loss = SimpleChains.Loss(my_loss_fun, y) # proposed syntax
train_loss = SimpleChains.add_loss(s, my_loss) # proposed syntax
```
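Written out, and assuming `arg` is the model output $\hat{y}$ with the same $N$ entries as `y`, `my_loss_fun` computes the relative squared error:

$$
\mathcal{L}(\hat{y}) = \frac{\sum_{i=1}^{N} (\hat{y}_i - y_i)^2}{\sum_{i=1}^{N} y_i^2}
$$

The denominator depends only on the targets `y`, so this particular loss is a plain sum of squared errors scaled by a positive constant and has the same minimizer. A general custom loss, however, need not reduce to one of the built-in forms.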

Is it possible to do something like this? What would it take to use my own loss with SimpleChains?
