
Commit 90a4335 (1 parent: 15c8590)

make loss(f,x,y) work

3 files changed: +71 -7 lines

docs/src/models/losses.md

Lines changed: 10 additions & 4 deletions
````diff
@@ -4,21 +4,27 @@ Flux provides a large number of common loss functions used for training machine
 They are grouped together in the `Flux.Losses` module.
 
 Loss functions for supervised learning typically expect as inputs a target `y`, and a prediction `ŷ` from your model.
-In Flux's convention, the order of the arguments is the following
+In Flux's convention, the target is the last argument:
 
 ```julia
 loss(ŷ, y)
 ```
 
+All loss functions have a method which takes the model as the first argument, and calculates the prediction `ŷ = model(x)`.
+This is convenient for [`train!`](@ref Flux.train)`(loss, model, [(x,y), (x2,y2), ...], opt)`:
+
+```julia
+loss(model, x, y) = loss(model(x), y)
+```
+
 Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the
 batch:
 
 ```julia
 loss(ŷ, y)                         # defaults to `mean`
-loss(ŷ, y, agg=sum)                # use `sum` for reduction
-loss(ŷ, y, agg=x->sum(x, dims=2))  # partial reduction
+loss(ŷ, y, agg=sum)                # use `sum` instead
 loss(ŷ, y, agg=x->mean(w .* x))    # weighted mean
-loss(ŷ, y, agg=identity)           # no aggregation.
+loss(ŷ, y, agg=x->sum(x, dims=2))  # partial reduction, returns an array
 ```
 
 ### Function listing
````
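To make the documented convention concrete, here is a minimal usage sketch (not part of the commit; the toy `Dense` model and array sizes are illustrative assumptions):

```julia
using Flux

model = Dense(3 => 2)           # toy model: 3 inputs, 2 outputs
x = rand(Float32, 3, 5)         # batch of 5 inputs
y = rand(Float32, 2, 5)         # matching targets

Flux.mse(model(x), y)           # classic 2-arg form: loss(ŷ, y)
Flux.mse(model, x, y)           # new 3-arg form added by this commit
Flux.mse(model, x, y, agg=sum)  # keyword arguments pass through
```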

src/losses/Losses.jl

Lines changed: 52 additions & 3 deletions
```diff
@@ -1,3 +1,28 @@
+"""
+    Flux.Losses
+
+This sub-module contains many loss functions, all of which accept two arguments,
+with the model output as the first argument: `loss(model(x), y)`.
+It also contains a few related utilities, such as `label_smoothing`.
+The complete list of exports is:
+
+    label_smoothing,
+    mse, mae, msle,
+    crossentropy,
+    logitcrossentropy,
+    binarycrossentropy,
+    logitbinarycrossentropy,
+    kldivergence,
+    huber_loss,
+    tversky_loss,
+    dice_coeff_loss,
+    poisson_loss,
+    hinge_loss,
+    squared_hinge_loss,
+    binary_focal_loss,
+    focal_loss,
+    siamese_contrastive_loss
+"""
 module Losses
 
 using Statistics
@@ -9,8 +34,8 @@ using CUDA
 using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
 import Base.Broadcast: broadcasted
 
-export mse, mae, msle,
-       label_smoothing,
+export label_smoothing,
+       mse, mae, msle,
        crossentropy, logitcrossentropy,
        binarycrossentropy, logitbinarycrossentropy,
        kldivergence,
@@ -19,9 +44,33 @@ export mse, mae, msle,
        dice_coeff_loss,
        poisson_loss,
        hinge_loss, squared_hinge_loss,
-       binary_focal_loss, focal_loss, siamese_contrastive_loss
+       binary_focal_loss, focal_loss,
+       siamese_contrastive_loss
 
 include("utils.jl")
 include("functions.jl")
 
+for loss in Symbol.([
+  mse, mae, msle,
+  crossentropy, logitcrossentropy,
+  binarycrossentropy, logitbinarycrossentropy,
+  kldivergence,
+  huber_loss,
+  tversky_loss,
+  dice_coeff_loss,
+  poisson_loss,
+  hinge_loss, squared_hinge_loss,
+  binary_focal_loss, focal_loss,
+  siamese_contrastive_loss,
+])
+  @eval begin
+    """
+        $($loss)(model, x, y)
+
+    This method calculates `ŷ = model(x)`. Accepts the same keyword arguments.
+    """
+    $loss(f, x::AbstractArray, y::AbstractArray; kw...) = $loss(f(x), y; kw...)
+  end
+end
+
 end #module
```
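The `for`/`@eval` loop is ordinary Julia metaprogramming: it stamps out one model-first method per loss name. A self-contained sketch of the same pattern (the `MiniLosses` module and its stand-in losses are invented for illustration, not Flux code):

```julia
module MiniLosses

# Stand-in two-argument losses, mimicking the loss(ŷ, y; agg) convention:
mse(ŷ, y; agg = sum) = agg(abs2.(ŷ .- y))
mae(ŷ, y; agg = sum) = agg(abs.(ŷ .- y))

# Generate a model-first 3-arg method for each name, as the commit's loop does:
for loss in (:mse, :mae)
    @eval $loss(f, x::AbstractArray, y::AbstractArray; kw...) = $loss(f(x), y; kw...)
end

end # module

MiniLosses.mse(x -> 2 .* x, [1.0, 2.0], [2.0, 4.0])  # 0.0: the "model" 2x hits y exactly
```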

test/losses.jl

Lines changed: 9 additions & 0 deletions
```diff
@@ -248,3 +248,12 @@ end
   @test_throws DomainError(-0.5, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ1, y1, margin = -0.5)
   @test_throws DomainError(-1, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ, y, margin = -1)
 end
+
+@testset "3-arg methods" begin
+  @testset for loss in ALL_LOSSES
+    fun(x) = x[1:2]
+    x = rand(3)
+    y = rand(2)
+    @test loss(fun, x, y) == loss(fun(x), y)
+  end
+end
```
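The test treats any callable as a model: `fun(x) = x[1:2]` just slices its input so the output shape matches `y`. A standalone version against two concrete losses (a sketch; `ALL_LOSSES` is a collection defined earlier in the test file, not reproduced here):

```julia
using Flux, Test

fun(x) = x[1:2]          # any callable works as the "model"
x, y = rand(3), rand(2)  # fun(x) has the same length as y

@test Flux.mse(fun, x, y) == Flux.mse(fun(x), y)
@test Flux.mae(fun, x, y) == Flux.mae(fun(x), y)
```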
