Commit 8ce74e1

Add log_cosh and huber loss (keras-team#67)
* Add log_cosh and huber loss
* Docstring standardization
* Format
* Standardize wrapper function docstrings
1 parent b1b1a4b commit 8ce74e1

File tree

2 files changed: +348 -1 lines changed

keras_core/losses/losses.py

Lines changed: 149 additions & 1 deletion
@@ -183,6 +183,63 @@ def __init__(
         )


+@keras_core_export("keras_core.losses.Huber")
+class Huber(LossFunctionWrapper):
+    """Computes the Huber loss between `y_true` & `y_pred`.
+
+    Formula:
+    ```python
+    for x in error:
+        if abs(x) <= delta:
+            loss.append(0.5 * x^2)
+        elif abs(x) > delta:
+            loss.append(delta * abs(x) - 0.5 * delta^2)
+
+    loss = mean(loss, axis=-1)
+    ```
+    See: [Huber loss](https://en.wikipedia.org/wiki/Huber_loss).
+
+    Args:
+        delta: A float, the point where the Huber loss function changes from a
+            quadratic to linear.
+        reduction: Type of reduction to apply to loss. Options are `"sum"`,
+            `"sum_over_batch_size"` or `None`. Defaults to
+            `"sum_over_batch_size"`.
+        name: Optional name for the instance.
+    """
+
+    def __init__(
+        self,
+        delta=1.0,
+        reduction="sum_over_batch_size",
+        name="huber_loss",
+    ):
+        super().__init__(huber, name=name, reduction=reduction, delta=delta)
+
+
+@keras_core_export("keras_core.losses.LogCosh")
+class LogCosh(LossFunctionWrapper):
+    """Computes the logarithm of the hyperbolic cosine of the prediction error.
+
+    Formula:
+
+    ```python
+    error = y_pred - y_true
+    logcosh = log((exp(error) + exp(-error))/2)
+    ```
+    where `error` is the prediction error `y_pred - y_true`.
+
+    Args:
+        reduction: Type of reduction to apply to loss. Options are `"sum"`,
+            `"sum_over_batch_size"` or `None`. Defaults to
+            `"sum_over_batch_size"`.
+        name: Optional name for the instance.
+    """
+
+    def __init__(self, reduction="sum_over_batch_size", name="log_cosh"):
+        super().__init__(log_cosh, name=name, reduction=reduction)
+
+
 @keras_core_export("keras_core.losses.Hinge")
 class Hinge(LossFunctionWrapper):
     """Computes the hinge loss between `y_true` & `y_pred`.
@@ -1063,7 +1120,7 @@ def mean_squared_error(y_true, y_pred):
     loss = mean(square(y_true - y_pred), axis=-1)
     ```

-    Standalone usage:
+    Example:

     >>> y_true = np.random.randint(0, 2, size=(2, 3))
     >>> y_pred = np.random.random(size=(2, 3))
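The doctest above uses random inputs; as a quick standalone check of the formula shown in the surrounding context lines, the same per-sample reduction can be written in plain NumPy with arbitrary fixed values:

```python
import numpy as np

y_true = np.array([[0.0, 1.0, 0.0], [1.0, 1.0, 1.0]])
y_pred = np.array([[0.1, 0.8, 0.2], [0.9, 0.7, 1.0]])

# mean(square(y_true - y_pred), axis=-1): one loss value per sample.
per_sample = np.mean(np.square(y_true - y_pred), axis=-1)
print(per_sample.shape)  # (2,)
```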
@@ -1237,6 +1294,97 @@ def cosine_similarity(y_true, y_pred, axis=-1):
     return -ops.sum(y_true * y_pred, axis=axis)


+@keras_core_export(["keras_core.losses.huber", "keras_core.metrics.huber"])
+def huber(y_true, y_pred, delta=1.0):
+    """Computes Huber loss value.
+
+    Formula:
+    ```python
+    for x in error:
+        if abs(x) <= delta:
+            loss.append(0.5 * x^2)
+        elif abs(x) > delta:
+            loss.append(delta * abs(x) - 0.5 * delta^2)
+
+    loss = mean(loss, axis=-1)
+    ```
+    See: [Huber loss](https://en.wikipedia.org/wiki/Huber_loss).
+
+    Example:
+
+    >>> y_true = [[0, 1], [0, 0]]
+    >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
+    >>> loss = keras_core.losses.huber(y_true, y_pred)
+    0.155
+
+    Args:
+        y_true: tensor of true targets.
+        y_pred: tensor of predicted targets.
+        delta: A float, the point where the Huber loss function changes from a
+            quadratic to linear. Defaults to 1.
+
+    Returns:
+        Tensor with one scalar loss entry per sample.
+    """
+    y_pred = ops.convert_to_tensor(y_pred)
+    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
+    y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
+    delta = ops.convert_to_tensor(delta)
+    error = ops.subtract(y_pred, y_true)
+    abs_error = ops.abs(error)
+    half = ops.convert_to_tensor(0.5, dtype=abs_error.dtype)
+    return ops.mean(
+        ops.where(
+            abs_error <= delta,
+            half * ops.square(error),
+            delta * abs_error - half * ops.square(delta),
+        ),
+        axis=-1,
+    )
+
+
+@keras_core_export(
+    ["keras_core.losses.log_cosh", "keras_core.metrics.log_cosh"]
+)
+def log_cosh(y_true, y_pred):
+    """Logarithm of the hyperbolic cosine of the prediction error.
+
+    Formula:
+    ```python
+    loss = mean(log(cosh(y_pred - y_true)), axis=-1)
+    ```
+
+    Note that `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small
+    `x` and to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works
+    mostly like the mean squared error, but will not be so strongly affected by
+    the occasional wildly incorrect prediction.
+
+    Example:
+
+    >>> y_true = [[0., 1.], [0., 0.]]
+    >>> y_pred = [[1., 1.], [0., 0.]]
+    >>> loss = keras_core.losses.log_cosh(y_true, y_pred)
+    0.108
+
+    Args:
+        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.
+        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.
+
+    Returns:
+        Logcosh error values with shape = `[batch_size, d0, .. dN-1]`.
+    """
+    y_pred = ops.convert_to_tensor(y_pred)
+    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
+    y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
+    log2 = ops.convert_to_tensor(ops.log(2.0), dtype=y_pred.dtype)
+
+    def _logcosh(x):
+        return x + ops.softplus(-2.0 * x) - log2
+
+    return ops.mean(_logcosh(y_pred - y_true), axis=-1)
+
+
 @keras_core_export(
     [
         "keras_core.metrics.kl_divergence",

keras_core/losses/losses_test.py

Lines changed: 199 additions & 0 deletions
@@ -465,6 +465,205 @@ def test_axis(self):
         self.assertAlmostEqual(loss, expected_loss, 3)


+class HuberLossTest(testing.TestCase):
+    def huber_loss(self, y_true, y_pred, delta=1.0):
+        error = y_pred - y_true
+        abs_error = np.abs(error)
+
+        quadratic = np.minimum(abs_error, delta)
+        linear = np.subtract(abs_error, quadratic)
+        return np.add(
+            np.multiply(0.5, np.multiply(quadratic, quadratic)),
+            np.multiply(delta, linear),
+        )
+
+    def setup(self, delta=1.0):
+        self.np_y_pred = np.array([[0.9, 0.2, 0.2], [0.8, 0.4, 0.6]])
+        self.np_y_true = np.array([[1.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
+
+        self.batch_size = 6
+        self.expected_losses = self.huber_loss(
+            self.np_y_true, self.np_y_pred, delta
+        )
+
+        self.y_pred = self.np_y_pred
+        self.y_true = self.np_y_true
+
+    def test_config(self):
+        h_obj = losses.Huber(reduction="sum", name="huber")
+        self.assertEqual(h_obj.name, "huber")
+        self.assertEqual(h_obj.reduction, "sum")
+
+    def test_all_correct(self):
+        self.setup()
+        h_obj = losses.Huber()
+        loss = h_obj(self.y_true, self.y_true)
+        self.assertAlmostEqual(loss, 0.0, 3)
+
+    def test_unweighted(self):
+        self.setup()
+        h_obj = losses.Huber()
+        loss = h_obj(self.y_true, self.y_pred)
+        actual_loss = np.sum(self.expected_losses) / self.batch_size
+        self.assertAlmostEqual(loss, actual_loss, 3)
+
+    def test_scalar_weighted(self):
+        self.setup()
+        h_obj = losses.Huber()
+        sample_weight = 2.3
+        loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
+        actual_loss = (
+            sample_weight * np.sum(self.expected_losses) / self.batch_size
+        )
+        self.assertAlmostEqual(loss, actual_loss, 3)
+
+        # Verify we get the same output when the same input is given
+        loss_2 = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
+        self.assertAlmostEqual(loss, loss_2, 3)
+
+    def test_sample_weighted(self):
+        self.setup()
+        h_obj = losses.Huber()
+        sample_weight = np.array([[1.2], [3.4]])
+
+        loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
+        actual_loss = np.multiply(
+            self.expected_losses,
+            np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)),
+        )
+        actual_loss = np.sum(actual_loss) / self.batch_size
+        self.assertAlmostEqual(loss, actual_loss, 3)
+
+    def test_timestep_weighted(self):
+        self.setup()
+        h_obj = losses.Huber()
+        y_pred = self.np_y_pred.reshape((2, 3, 1))
+        y_true = self.np_y_true.reshape((2, 3, 1))
+        expected_losses = self.huber_loss(y_true, y_pred)
+
+        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3, 1))
+        loss = h_obj(
+            y_true,
+            y_pred,
+            sample_weight=sample_weight,
+        )
+        actual_loss = np.multiply(expected_losses, sample_weight)
+        actual_loss = np.sum(actual_loss) / self.batch_size
+        self.assertAlmostEqual(loss, actual_loss, 3)
+
+    def test_zero_weighted(self):
+        self.setup()
+        h_obj = losses.Huber()
+        sample_weight = 0
+        loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
+        self.assertAlmostEqual(loss, 0.0, 3)
+
+    def test_non_default_delta(self):
+        self.setup(delta=0.8)
+        h_obj = losses.Huber(delta=0.8)
+        sample_weight = 2.3
+        loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
+        actual_loss = (
+            sample_weight * np.sum(self.expected_losses) / self.batch_size
+        )
+        self.assertAlmostEqual(loss, actual_loss, 3)
+
+    def test_loss_with_non_default_dtype(self):
+        # Test case for GitHub issue:
+        # https://github.com/tensorflow/tensorflow/issues/39004
+        # TODO
+        pass
+
+
+class LogCoshTest(testing.TestCase):
+    def setup(self):
+        y_true = np.asarray([[1, 9, 2], [-5, -2, 6]], dtype=np.float32)
+        y_pred = np.asarray([[4, 8, 12], [8, 1, 3]], dtype=np.float32)
+
+        self.batch_size = 6
+        error = y_pred - y_true
+        self.expected_losses = np.log((np.exp(error) + np.exp(-error)) / 2)
+
+        self.y_true = y_true
+        self.y_pred = y_pred
+
+    def test_config(self):
+        logcosh_obj = losses.LogCosh(reduction="sum", name="logcosh_loss")
+        self.assertEqual(logcosh_obj.name, "logcosh_loss")
+        self.assertEqual(logcosh_obj.reduction, "sum")
+
+    def test_unweighted(self):
+        self.setup()
+        logcosh_obj = losses.LogCosh()
+
+        loss = logcosh_obj(self.y_true, self.y_pred)
+        expected_loss = np.sum(self.expected_losses) / self.batch_size
+        self.assertAlmostEqual(loss, expected_loss, 3)
+
+    def test_scalar_weighted(self):
+        self.setup()
+        logcosh_obj = losses.LogCosh()
+        sample_weight = 2.3
+
+        loss = logcosh_obj(
+            self.y_true, self.y_pred, sample_weight=sample_weight
+        )
+        expected_loss = (
+            sample_weight * np.sum(self.expected_losses) / self.batch_size
+        )
+        self.assertAlmostEqual(loss, expected_loss, 3)
+
+        # Verify we get the same output when the same input is given
+        loss_2 = logcosh_obj(
+            self.y_true, self.y_pred, sample_weight=sample_weight
+        )
+        self.assertAlmostEqual(loss, loss_2, 3)
+
+    def test_sample_weighted(self):
+        self.setup()
+        logcosh_obj = losses.LogCosh()
+
+        sample_weight = np.asarray([1.2, 3.4])
+        loss = logcosh_obj(
+            self.y_true, self.y_pred, sample_weight=sample_weight
+        )
+
+        expected_loss = np.multiply(
+            self.expected_losses,
+            np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)),
+        )
+        expected_loss = np.sum(expected_loss) / self.batch_size
+        self.assertAlmostEqual(loss, expected_loss, 3)
+
+    def test_timestep_weighted(self):
+        self.setup()
+        logcosh_obj = losses.LogCosh()
+        y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)
+        y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)
+        error = y_pred - y_true
+        expected_losses = np.log((np.exp(error) + np.exp(-error)) / 2)
+        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3, 1))

+        loss = logcosh_obj(
+            y_true,
+            y_pred,
+            sample_weight=sample_weight,
+        )
+        expected_loss = (
+            np.sum(expected_losses * sample_weight) / self.batch_size
+        )
+        self.assertAlmostEqual(loss, expected_loss, 3)
+
+    def test_zero_weighted(self):
+        self.setup()
+        logcosh_obj = losses.LogCosh()
+        sample_weight = 0
+        loss = logcosh_obj(
+            self.y_true, self.y_pred, sample_weight=sample_weight
+        )
+        self.assertAlmostEqual(loss, 0.0, 3)
+
+
 class KLDivergenceTest(testing.TestCase):
     def setup(self):
         self.y_pred = np.asarray(
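A side note on the test strategy: `HuberLossTest.huber_loss` builds the reference value from a `quadratic` part (capped at `delta`) plus a `linear` remainder instead of branching explicitly. That formulation agrees with the piecewise docstring formula, as this standalone NumPy sketch (independent of keras_core) checks:

```python
import numpy as np


def huber_reference(y_true, y_pred, delta=1.0):
    # Mirrors HuberLossTest.huber_loss: quadratic part capped at delta,
    # linear remainder beyond it.
    abs_error = np.abs(y_pred - y_true)
    quadratic = np.minimum(abs_error, delta)
    linear = abs_error - quadratic
    return 0.5 * quadratic**2 + delta * linear


def huber_piecewise(y_true, y_pred, delta=1.0):
    # Direct transcription of the docstring formula.
    error = y_pred - y_true
    abs_error = np.abs(error)
    return np.where(
        abs_error <= delta,
        0.5 * np.square(error),
        delta * abs_error - 0.5 * delta**2,
    )


rng = np.random.default_rng(0)
y_true = rng.normal(size=(4, 5))
y_pred = rng.normal(size=(4, 5))
assert np.allclose(
    huber_reference(y_true, y_pred, delta=0.8),
    huber_piecewise(y_true, y_pred, delta=0.8),
)
```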
