Skip to content

[keras/src/print.py] Implement backend-specialised print function #21344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from keras.src.ops.core import fori_loop as fori_loop
from keras.src.ops.core import is_tensor as is_tensor
from keras.src.ops.core import map as map
from keras.src.ops.core import print as print
from keras.src.ops.core import saturate_cast as saturate_cast
from keras.src.ops.core import scan as scan
from keras.src.ops.core import scatter as scatter
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from keras.src.ops.core import fori_loop as fori_loop
from keras.src.ops.core import is_tensor as is_tensor
from keras.src.ops.core import map as map
from keras.src.ops.core import print as print
from keras.src.ops.core import saturate_cast as saturate_cast
from keras.src.ops.core import scan as scan
from keras.src.ops.core import scatter as scatter
Expand Down
10 changes: 10 additions & 0 deletions keras/src/backend/jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,13 @@ def device_scope(device_name):
else:
jax_device = device_name
return jax.default_device(jax_device)


def print(*args, **kwargs):
"""Prints values and works in staged out JAX functions.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The backend specific docstring is not visible anywhere (neither on keras.io, nor in an IDE). So move to the docstring of keras/src/ops/core.py explaining that this is JAX specific.


This function does *not* work with f-strings because formatting is delayed.
So instead of ``jax.debug.print(f"hello {bar}")``, write
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Single back ticks ` here and on the next line

``jax.debug.print("hello {bar}", bar=bar)``.
"""
return jax.debug.print(*args, **kwargs)
11 changes: 11 additions & 0 deletions keras/src/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,3 +696,14 @@ def __exit__(self, *args, **kwargs):

def device_scope(device_name):
return tf.device(device_name)


def print(*args, **kwargs):
"""Print the specified inputs.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The backend specific docstring is not visible anywhere (neither on keras.io, nor in an IDE). So if you think it is useful, move to the docstring of keras/src/ops/core.py explaining that this is TensorFlow specific.


A TensorFlow operator that prints the specified inputs to a desired
output stream or logging level. The inputs may be dense or sparse Tensors,
primitive python objects, data structures that contain tensors, and
printable Python objects. Printed tensors will recursively show the first
and last elements of each dimension to summarize."""
return tf.print(*args, **kwargs)
10 changes: 10 additions & 0 deletions keras/src/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,3 +1183,13 @@ def grad(*args, upstream):
```
"""
return backend.core.custom_gradient(f)


_print = print


@keras_export("keras.ops.print")
def print(*args, **kwargs):
return (backend.core.print if hasattr(backend.core, "print") else _print)(
*args, **kwargs
)