
Commit 6c2fc07

Implement all crossentropy losses.

1 parent 5b72930
6 files changed: +1075 −64 lines

keras_core/backend/jax/nn.py

Lines changed: 75 additions & 0 deletions
@@ -4,6 +4,8 @@
 from jax import lax
 from jax import nn as jnn
 
+from keras_core.backend.config import epsilon
+
 
 def relu(x):
     return jnn.relu(x)
@@ -356,3 +358,76 @@ def conv_transpose(
 
 def one_hot(x, num_classes, axis=-1):
     return jnn.one_hot(x, num_classes, axis=axis)
+
+
+def categorical_crossentropy(target, output, from_logits=False, axis=-1):
+    target = jnp.array(target)
+    output = jnp.array(output)
+
+    if target.shape != output.shape:
+        raise ValueError(
+            "Arguments `target` and `output` must have the same shape. "
+            "Received: "
+            f"target.shape={target.shape}, output.shape={output.shape}"
+        )
+    if len(target.shape) < 1:
+        raise ValueError(
+            "Arguments `target` and `output` must be at least rank 1. "
+            "Received: "
+            f"target.shape={target.shape}, output.shape={output.shape}"
+        )
+    if from_logits:
+        log_prob = jax.nn.log_softmax(output, axis=axis)
+    else:
+        output = output / jnp.sum(output, axis, keepdims=True)
+        output = jnp.clip(output, epsilon(), 1.0 - epsilon())
+        log_prob = jnp.log(output)
+    return -jnp.sum(target * log_prob, axis=axis)
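As a quick sanity check, here is a minimal usage sketch (not part of the commit) of the function above, assuming the package is installed and the module is importable as `keras_core.backend.jax.nn`:

import jax.numpy as jnp

from keras_core.backend.jax import nn

# One-hot targets and row-normalized probabilities.
target = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
output = jnp.array([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1]])

loss = nn.categorical_crossentropy(target, output)
# Per-sample loss is -sum(target * log(output)) along the class axis:
# approximately [-log(0.9), -log(0.8)] = [0.10536, 0.22314]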
+def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
+    target = jnp.array(target, dtype="int64")
+    output = jnp.array(output)
+    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
+        target = jnp.squeeze(target, axis=-1)
+
+    if len(output.shape) < 1:
+        raise ValueError(
+            "Argument `output` must be at least rank 1. "
+            "Received: "
+            f"output.shape={output.shape}"
+        )
+    if target.shape != output.shape[:-1]:
+        raise ValueError(
+            "Arguments `target` and `output` must have the same shape "
+            "up until the last dimension: "
+            f"target.shape={target.shape}, output.shape={output.shape}"
+        )
+    if from_logits:
+        log_prob = jax.nn.log_softmax(output, axis=axis)
+    else:
+        output = output / jnp.sum(output, axis, keepdims=True)
+        output = jnp.clip(output, epsilon(), 1.0 - epsilon())
+        log_prob = jnp.log(output)
+    target = jnn.one_hot(target, output.shape[axis], axis=axis)
+    return -jnp.sum(target * log_prob, axis=axis)
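The sparse variant accepts integer class indices instead of one-hot vectors and one-hot encodes them internally before taking the same reduction. A sketch under the same import assumption:

import jax.numpy as jnp

from keras_core.backend.jax import nn

target = jnp.array([0, 1])  # integer class indices
output = jnp.array([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1]])

loss = nn.sparse_categorical_crossentropy(target, output)
# Matches the one-hot version: approximately [0.10536, 0.22314]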
+def binary_crossentropy(target, output, from_logits=False):
+    target = jnp.array(target)
+    output = jnp.array(output)
+
+    if target.shape != output.shape:
+        raise ValueError(
+            "Arguments `target` and `output` must have the same shape. "
+            "Received: "
+            f"target.shape={target.shape}, output.shape={output.shape}"
+        )
+
+    if from_logits:
+        output = jnn.sigmoid(output)
+
+    output = jnp.clip(output, epsilon(), 1.0 - epsilon())
+    bce = target * jnp.log(output)
+    bce += (1.0 - target) * jnp.log(1.0 - output)
+    return -bce
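And the elementwise binary loss, again as a hypothetical usage sketch rather than anything shipped in the commit:

import jax.numpy as jnp

from keras_core.backend.jax import nn

target = jnp.array([0.0, 1.0, 1.0])
output = jnp.array([0.1, 0.8, 0.6])

loss = nn.binary_crossentropy(target, output)
# Elementwise -(t * log(p) + (1 - t) * log(1 - p)):
# approximately [0.10536, 0.22314, 0.51083]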

keras_core/backend/tensorflow/nn.py

Lines changed: 203 additions & 0 deletions
@@ -1,8 +1,11 @@
+import warnings
+
 import tensorflow as tf
 
 from keras_core.backend.common.backend_utils import (
     compute_conv_transpose_output_shape,
 )
+from keras_core.backend.config import epsilon
 
 
 def relu(x):
@@ -302,3 +305,203 @@ def conv_transpose(
 
 def one_hot(x, num_classes, axis=-1):
     return tf.one_hot(x, num_classes, axis=axis)
+
+
+def _get_logits(output, from_logits, op_type, fn_name):
+    """Retrieves logits tensor from maybe-softmax or maybe-sigmoid tensor."""
+    output_ = output
+    from_logits_ = from_logits
+
+    has_keras_logits = hasattr(output, "_keras_logits")
+    if has_keras_logits:
+        output_ = output._keras_logits
+        from_logits_ = True
+
+    from_expected_op_type = (
+        not isinstance(output, (tf.__internal__.EagerTensor, tf.Variable))
+        and output.op.type == op_type
+    ) and not has_keras_logits
+
+    if from_expected_op_type:
+        # When a softmax activation was used to produce `output`, we use the
+        # logits that fed the softmax directly to compute the loss, in order
+        # to prevent probabilities from collapsing to zero during training.
+        assert len(output.op.inputs) == 1
+        output_ = output.op.inputs[0]
+        from_logits_ = True
+
+    if from_logits and (has_keras_logits or from_expected_op_type):
+        warnings.warn(
+            f"`{fn_name}` received `from_logits=True`, but "
+            f"the `output` argument was produced by a {op_type} "
+            "activation and thus does not represent logits. "
+            "Was this intended?",
+            stacklevel=2,
+        )
+    return output_, from_logits_
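The helper only unwraps symbolic tensors (eager tensors are excluded by the `isinstance` check), so its effect is easiest to observe inside a `tf.function`. A minimal sketch, assuming the module is importable as `keras_core.backend.tensorflow.nn`; `loss_from_probs` is a hypothetical name:

import tensorflow as tf

from keras_core.backend.tensorflow import nn


@tf.function
def loss_from_probs(target, logits):
    # `probs` is produced by a Softmax op. In graph mode, _get_logits walks
    # back through `probs.op` to recover `logits`, so the loss below goes
    # through the numerically stable logits path even though we pass
    # probabilities with the default from_logits=False.
    probs = tf.nn.softmax(logits)
    return nn.categorical_crossentropy(target, probs)


target = tf.constant([[1.0, 0.0], [0.0, 1.0]])
logits = tf.constant([[2.0, -1.0], [0.5, 1.5]])
print(loss_from_probs(target, logits))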
+def categorical_crossentropy(target, output, from_logits=False, axis=-1):
+    """Categorical crossentropy between an output tensor and a target tensor.
+
+    Args:
+        target: A tensor of the same shape as `output`.
+        output: A tensor resulting from a softmax
+            (unless `from_logits` is `True`, in which
+            case `output` is expected to be the logits).
+        from_logits: Boolean, whether `output` is the
+            result of a softmax, or is a tensor of logits.
+        axis: Int specifying the channels axis. `axis=-1` corresponds to data
+            format `channels_last`, and `axis=1` corresponds to data format
+            `channels_first`.
+
+    Returns:
+        Output tensor.
+
+    Example:
+
+    >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3, 3])
+    >>> print(a)
+    tf.Tensor(
+      [[1. 0. 0.]
+       [0. 1. 0.]
+       [0. 0. 1.]], shape=(3, 3), dtype=float32)
+    >>> b = tf.constant([.9, .05, .05, .05, .89, .06, .05, .01, .94],
+    ...                 shape=[3, 3])
+    >>> print(b)
+    tf.Tensor(
+      [[0.9  0.05 0.05]
+       [0.05 0.89 0.06]
+       [0.05 0.01 0.94]], shape=(3, 3), dtype=float32)
+    >>> loss = categorical_crossentropy(a, b)
+    >>> print(np.around(loss, 5))
+    [0.10536 0.11653 0.06188]
+    >>> loss = categorical_crossentropy(a, a)
+    >>> print(np.around(loss, 5))
+    [0. 0. 0.]
+    """
+    target = tf.convert_to_tensor(target)
+    output = tf.convert_to_tensor(output)
+
+    if target.shape != output.shape:
+        raise ValueError(
+            "Arguments `target` and `output` must have the same shape. "
+            "Received: "
+            f"target.shape={target.shape}, output.shape={output.shape}"
+        )
+    if len(target.shape) < 1:
+        raise ValueError(
+            "Arguments `target` and `output` must be at least rank 1. "
+            "Received: "
+            f"target.shape={target.shape}, output.shape={output.shape}"
+        )
+
+    output, from_logits = _get_logits(
+        output, from_logits, "Softmax", "categorical_crossentropy"
+    )
+    if from_logits:
+        return tf.nn.softmax_cross_entropy_with_logits(
+            labels=target, logits=output, axis=axis
+        )
+
+    # Adjust the predictions so that the probability of each class for
+    # every sample adds up to 1. This is needed to ensure that the cross
+    # entropy is computed correctly.
+    output = output / tf.reduce_sum(output, axis, keepdims=True)
+
+    # Compute cross entropy from probabilities.
+    output = tf.clip_by_value(output, epsilon(), 1.0 - epsilon())
+    return -tf.reduce_sum(target * tf.math.log(output), axis)
+def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
+    """Categorical crossentropy with integer targets.
+
+    Args:
+        target: An integer tensor.
+        output: A tensor resulting from a softmax
+            (unless `from_logits` is True, in which
+            case `output` is expected to be the logits).
+        from_logits: Boolean, whether `output` is the
+            result of a softmax, or is a tensor of logits.
+        axis: Int specifying the channels axis. `axis=-1` corresponds to data
+            format `channels_last`, and `axis=1` corresponds to data format
+            `channels_first`.
+
+    Returns:
+        Output tensor.
+    """
+    if axis != -1 and axis != len(output.shape) - 1:
+        raise ValueError(
+            f"Only axis=-1 is currently supported. Received: axis={axis}"
+        )
+
+    target = tf.convert_to_tensor(target)
+    target = tf.cast(target, dtype="int64")
+    output = tf.convert_to_tensor(output)
+    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
+        target = tf.squeeze(target, axis=-1)
+
+    if len(output.shape) < 1:
+        raise ValueError(
+            "Argument `output` must be at least rank 1. "
+            "Received: "
+            f"output.shape={output.shape}"
+        )
+    if target.shape != output.shape[:-1]:
+        raise ValueError(
+            "Arguments `target` and `output` must have the same shape "
+            "up until the last dimension: "
+            f"target.shape={target.shape}, output.shape={output.shape}"
+        )
+
+    output, from_logits = _get_logits(
+        output, from_logits, "Softmax", "sparse_categorical_crossentropy"
+    )
+    if not from_logits:
+        output = tf.clip_by_value(output, epsilon(), 1 - epsilon())
+        output = tf.math.log(output)
+
+    result = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        labels=target, logits=output
+    )
+    return result
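Unlike the TF docstring for the dense version, this one carries no example, so here is a hedged usage sketch (same module-path assumption as above):

import tensorflow as tf

from keras_core.backend.tensorflow import nn

target = tf.constant([0, 1])  # integer class indices
output = tf.constant([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1]])

loss = nn.sparse_categorical_crossentropy(target, output)
# -log(p[target]) per sample: approximately [0.10536, 0.22314]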
+def binary_crossentropy(target, output, from_logits=False):
+    """Binary crossentropy between an output tensor and a target tensor.
+
+    Args:
+        target: A tensor with the same shape as `output`.
+        output: A tensor.
+        from_logits: Whether `output` is expected to be a logits tensor.
+            By default, we consider that `output` encodes a probability
+            distribution.
+
+    Returns:
+        A tensor.
+    """
+    target = tf.convert_to_tensor(target)
+    output = tf.convert_to_tensor(output)
+
+    if target.shape != output.shape:
+        raise ValueError(
+            "Arguments `target` and `output` must have the same shape. "
+            "Received: "
+            f"target.shape={target.shape}, output.shape={output.shape}"
+        )
+
+    output, from_logits = _get_logits(
+        output, from_logits, "Sigmoid", "binary_crossentropy"
+    )
+    if from_logits:
+        return tf.nn.sigmoid_cross_entropy_with_logits(
+            labels=target, logits=output
+        )
+
+    # Compute cross entropy from probabilities.
+    output = tf.clip_by_value(output, epsilon(), 1.0 - epsilon())
+    bce = target * tf.math.log(output)
+    bce += (1 - target) * tf.math.log(1 - output)
+    return -bce
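For completeness, a matching sketch for the TF binary loss, under the same assumptions as the examples above:

import tensorflow as tf

from keras_core.backend.tensorflow import nn

target = tf.constant([0.0, 1.0, 1.0])
output = tf.constant([0.1, 0.8, 0.6])

loss = nn.binary_crossentropy(target, output)
# Elementwise -(t * log(p) + (1 - t) * log(1 - p)):
# approximately [0.10536, 0.22314, 0.51083]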
