| 1 | +import warnings |
| 2 | + |
1 | 3 | import tensorflow as tf |
2 | 4 |
3 | 5 | from keras_core.backend.common.backend_utils import ( |
4 | 6 | compute_conv_transpose_output_shape, |
5 | 7 | ) |
| 8 | +from keras_core.backend.config import epsilon |
6 | 9 |
7 | 10 |
8 | 11 | def relu(x): |
@@ -302,3 +305,203 @@ def conv_transpose( |
302 | 305 |
303 | 306 | def one_hot(x, num_classes, axis=-1): |
304 | 307 | return tf.one_hot(x, num_classes, axis=axis) |
| 308 | + |
| 309 | + |
| 310 | +def _get_logits(output, from_logits, op_type, fn_name): |
| 311 | + """Retrieves logits tensor from maybe-softmax or maybe-sigmoid tensor.""" |
| 312 | + output_ = output |
| 313 | + from_logits_ = from_logits |
| 314 | + |
| 315 | + has_keras_logits = hasattr(output, "_keras_logits") |
| 316 | + if has_keras_logits: |
| 317 | + output_ = output._keras_logits |
| 318 | + from_logits_ = True |
| 319 | + |
| 320 | + from_expected_op_type = ( |
| 321 | + not isinstance(output, (tf.__internal__.EagerTensor, tf.Variable)) |
| 322 | + and output.op.type == op_type |
| 323 | + ) and not has_keras_logits |
| 324 | + |
| 325 | + if from_expected_op_type: |
| 326 | +        # When `output` was produced by a softmax op, compute the loss |
| 327 | +        # directly from the logits that fed the softmax; this avoids the |
| 328 | +        # numerical collapse to zero that clipped probabilities can cause. |
| 329 | + assert len(output.op.inputs) == 1 |
| 330 | + output_ = output.op.inputs[0] |
| 331 | + from_logits_ = True |
| 332 | + |
| 333 | + if from_logits and (has_keras_logits or from_expected_op_type): |
| 334 | + warnings.warn( |
| 335 | +            f'`{fn_name}` received `from_logits=True`, but ' |
| 336 | + f"the `output` argument was produced by a {op_type} " |
| 337 | + "activation and thus does not represent logits. " |
| 338 | + "Was this intended?", |
| 339 | + stacklevel=2, |
| 340 | + ) |
| 341 | + return output_, from_logits_ |
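A minimal sketch of the graph-mode mechanics `_get_logits` relies on (made-up values; in eager mode the `EagerTensor` check deliberately skips this path, and the `_keras_logits` attribute is attached by Keras activation layers rather than by user code):

```python
import tensorflow as tf

# In graph mode, a tensor produced by tf.nn.softmax carries an op of
# type "Softmax" whose single input is the pre-activation tensor --
# exactly what the `from_expected_op_type` branch above recovers.
with tf.Graph().as_default():
    logits = tf.constant([[2.0, 1.0, 0.1]])
    probs = tf.nn.softmax(logits)
    print(probs.op.type)             # "Softmax"
    print(probs.op.inputs[0].name)   # same tensor as `logits`
```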
| 342 | + |
| 343 | + |
| 344 | +def categorical_crossentropy(target, output, from_logits=False, axis=-1): |
| 345 | + """Categorical crossentropy between an output tensor and a target tensor. |
| 346 | +
| 347 | + Args: |
| 348 | + target: A tensor of the same shape as `output`. |
| 349 | + output: A tensor resulting from a softmax |
| 350 | + (unless `from_logits` is `True`, in which |
| 351 | + case `output` is expected to be the logits). |
| 352 | + from_logits: Boolean, whether `output` is the |
| 353 | + result of a softmax, or is a tensor of logits. |
| 354 | + axis: Int specifying the channels axis. `axis=-1` corresponds to data |
| 355 | + format `channels_last`, and `axis=1` corresponds to data format |
| 356 | + `channels_first`. |
| 357 | +
| 358 | + Returns: |
| 359 | + Output tensor. |
| 360 | +
| 361 | + Example: |
| 362 | +
| 363 | +    >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3, 3]) |
| 364 | + >>> print(a) |
| 365 | + tf.Tensor( |
| 366 | + [[1. 0. 0.] |
| 367 | + [0. 1. 0.] |
| 368 | + [0. 0. 1.]], shape=(3, 3), dtype=float32) |
| 369 | + >>> b = tf.constant([.9, .05, .05, .05, .89, .06, .05, .01, .94], |
| 370 | + ... shape=[3, 3]) |
| 371 | + >>> print(b) |
| 372 | + tf.Tensor( |
| 373 | + [[0.9 0.05 0.05] |
| 374 | + [0.05 0.89 0.06] |
| 375 | + [0.05 0.01 0.94]], shape=(3, 3), dtype=float32) |
| 376 | + >>> loss = categorical_crossentropy(a, b) |
| 377 | + >>> print(np.around(loss, 5)) |
| 378 | + [0.10536 0.11653 0.06188] |
| 379 | + >>> loss = categorical_crossentropy(a, a) |
| 380 | + >>> print(np.around(loss, 5)) |
| 381 | + [0. 0. 0.] |
| 382 | + """ |
| 383 | + target = tf.convert_to_tensor(target) |
| 384 | + output = tf.convert_to_tensor(output) |
| 385 | + |
| 386 | + if target.shape != output.shape: |
| 387 | + raise ValueError( |
| 388 | + "Arguments `target` and `output` must have the same shape. " |
| 389 | + "Received: " |
| 390 | + f"target.shape={target.shape}, output.shape={output.shape}" |
| 391 | + ) |
| 392 | + if len(target.shape) < 1: |
| 393 | + raise ValueError( |
| 394 | + "Arguments `target` and `output` must be at least rank 1. " |
| 395 | + "Received: " |
| 396 | + f"target.shape={target.shape}, output.shape={output.shape}" |
| 397 | + ) |
| 398 | + |
| 399 | + output, from_logits = _get_logits( |
| 400 | + output, from_logits, "Softmax", "categorical_crossentropy" |
| 401 | + ) |
| 402 | + if from_logits: |
| 403 | + return tf.nn.softmax_cross_entropy_with_logits( |
| 404 | + labels=target, logits=output, axis=axis |
| 405 | + ) |
| 406 | + |
| 407 | +    # Adjust the predictions so that the probability of |
| 408 | +    # each class for every sample adds up to 1. |
| 409 | +    # This is needed to ensure that the cross entropy is |
| 410 | +    # computed correctly. |
| 411 | + output = output / tf.reduce_sum(output, axis, keepdims=True) |
| 412 | + |
| 413 | + # Compute cross entropy from probabilities. |
| 414 | + output = tf.clip_by_value(output, epsilon(), 1.0 - epsilon()) |
| 415 | + return -tf.reduce_sum(target * tf.math.log(output), axis) |
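As a sanity check that the two branches agree, a short sketch (hypothetical values; assumes the function above is in scope): feeding raw logits with `from_logits=True` should match feeding the corresponding softmax probabilities.

```python
import numpy as np
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1], [0.5, 3.0, 0.2]])
target = tf.constant([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

# Logits branch vs. probability branch of categorical_crossentropy.
loss_from_logits = categorical_crossentropy(target, logits, from_logits=True)
loss_from_probs = categorical_crossentropy(target, tf.nn.softmax(logits))
np.testing.assert_allclose(
    loss_from_logits.numpy(), loss_from_probs.numpy(), rtol=1e-5
)
```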
| 416 | + |
| 417 | + |
| 418 | +def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): |
| 419 | + """Categorical crossentropy with integer targets. |
| 420 | +
| 421 | + Args: |
| 422 | + target: An integer tensor. |
| 423 | + output: A tensor resulting from a softmax |
| 424 | + (unless `from_logits` is True, in which |
| 425 | + case `output` is expected to be the logits). |
| 426 | + from_logits: Boolean, whether `output` is the |
| 427 | + result of a softmax, or is a tensor of logits. |
| 428 | + axis: Int specifying the channels axis. `axis=-1` corresponds to data |
| 429 | + format `channels_last`, and `axis=1` corresponds to data format |
| 430 | + `channels_first`. |
| 431 | +
| 432 | + Returns: |
| 433 | + Output tensor. |
| 434 | + """ |
| 435 | + if axis != -1 and axis != len(output.shape) - 1: |
| 436 | + raise ValueError( |
| 437 | + f"Only axis=-1 is currently supported. Received: axis={axis}" |
| 438 | + ) |
| 439 | + |
| 440 | + target = tf.convert_to_tensor(target) |
| 441 | + target = tf.cast(target, dtype="int64") |
| 442 | + output = tf.convert_to_tensor(output) |
| 443 | + if len(target.shape) == len(output.shape) and target.shape[-1] == 1: |
| 444 | + target = tf.squeeze(target, axis=-1) |
| 445 | + |
| 446 | + if len(output.shape) < 1: |
| 447 | + raise ValueError( |
| 448 | + "Argument `output` must be at least rank 1. " |
| 449 | + "Received: " |
| 450 | + f"output.shape={output.shape}" |
| 451 | + ) |
| 452 | + if target.shape != output.shape[:-1]: |
| 453 | + raise ValueError( |
| 454 | + "Arguments `target` and `output` must have the same shape " |
| 455 | + "up until the last dimension: " |
| 456 | + f"target.shape={target.shape}, output.shape={output.shape}" |
| 457 | + ) |
| 458 | + |
| 459 | + output, from_logits = _get_logits( |
| 460 | + output, from_logits, "Softmax", "sparse_categorical_crossentropy" |
| 461 | + ) |
| 462 | + if not from_logits: |
| 463 | +        output = tf.clip_by_value(output, epsilon(), 1.0 - epsilon()) |
| 464 | + output = tf.math.log(output) |
| 465 | + |
| 466 | + result = tf.nn.sparse_softmax_cross_entropy_with_logits( |
| 467 | + labels=target, logits=output |
| 468 | + ) |
| 469 | + return result |
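For intuition, a small sketch (hypothetical values): integer targets here behave like one-hot targets passed to `categorical_crossentropy` above.

```python
import numpy as np
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1], [0.5, 3.0, 0.2]])
sparse_target = tf.constant([0, 1])            # integer class indices
one_hot_target = tf.one_hot(sparse_target, 3)  # [[1,0,0], [0,1,0]]

sparse_loss = sparse_categorical_crossentropy(
    sparse_target, logits, from_logits=True
)
dense_loss = categorical_crossentropy(one_hot_target, logits, from_logits=True)
np.testing.assert_allclose(sparse_loss.numpy(), dense_loss.numpy(), rtol=1e-6)
```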
| 470 | + |
| 471 | + |
| 472 | +def binary_crossentropy(target, output, from_logits=False): |
| 473 | + """Binary crossentropy between an output tensor and a target tensor. |
| 474 | +
| 475 | + Args: |
| 476 | + target: A tensor with the same shape as `output`. |
| 477 | + output: A tensor. |
| 478 | + from_logits: Whether `output` is expected to be a logits tensor. |
| 479 | + By default, we consider that `output` |
| 480 | + encodes a probability distribution. |
| 481 | +
| 482 | + Returns: |
| 483 | + A tensor. |
| 484 | + """ |
| 485 | + target = tf.convert_to_tensor(target) |
| 486 | + output = tf.convert_to_tensor(output) |
| 487 | + |
| 488 | + if target.shape != output.shape: |
| 489 | + raise ValueError( |
| 490 | + "Arguments `target` and `output` must have the same shape. " |
| 491 | + "Received: " |
| 492 | + f"target.shape={target.shape}, output.shape={output.shape}" |
| 493 | + ) |
| 494 | + |
| 495 | + output, from_logits = _get_logits( |
| 496 | + output, from_logits, "Sigmoid", "binary_crossentropy" |
| 497 | + ) |
| 498 | + if from_logits: |
| 499 | + return tf.nn.sigmoid_cross_entropy_with_logits( |
| 500 | + labels=target, logits=output |
| 501 | + ) |
| 502 | + |
| 503 | + # Compute cross entropy from probabilities. |
| 504 | + output = tf.clip_by_value(output, epsilon(), 1.0 - epsilon()) |
| 505 | + bce = target * tf.math.log(output) |
| 506 | + bce += (1 - target) * tf.math.log(1 - output) |
| 507 | + return -bce |
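A small worked example (hypothetical values). Note that the function returns the elementwise crossentropy with no reduction:

```python
import tensorflow as tf

target = tf.constant([[1.0, 0.0, 1.0]])
probs = tf.constant([[0.9, 0.2, 0.6]])

loss = binary_crossentropy(target, probs)
# Elementwise: [-log(0.9), -log(1 - 0.2), -log(0.6)]
# ~= [[0.105, 0.223, 0.511]]
print(loss.numpy())
```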