-
Notifications
You must be signed in to change notification settings - Fork 19.6k
[keras/src/print.py] Implement backend-specialised print
function
#21344
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 3 commits
d31c128
4571fee
726837c
6c7db62
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -428,3 +428,13 @@ def device_scope(device_name): | |
else: | ||
jax_device = device_name | ||
return jax.default_device(jax_device) | ||
|
||
|
||
def print(*args, **kwargs): | ||
"""Prints values and works in staged out JAX functions. | ||
|
||
This function does *not* work with f-strings because formatting is delayed. | ||
So instead of ``jax.debug.print(f"hello {bar}")``, write | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Single back ticks ` here and on the next line |
||
``jax.debug.print("hello {bar}", bar=bar)``. | ||
""" | ||
return jax.debug.print(*args, **kwargs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -696,3 +696,14 @@ def __exit__(self, *args, **kwargs): | |
|
||
def device_scope(device_name): | ||
return tf.device(device_name) | ||
|
||
|
||
def print(*args, **kwargs): | ||
"""Print the specified inputs. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The backend specific docstring is not visible anywhere (neither on keras.io, nor in an IDE). So if you think it is useful, move to the docstring of |
||
|
||
A TensorFlow operator that prints the specified inputs to a desired | ||
output stream or logging level. The inputs may be dense or sparse Tensors, | ||
primitive python objects, data structures that contain tensors, and | ||
printable Python objects. Printed tensors will recursively show the first | ||
and last elements of each dimension to summarize.""" | ||
return tf.print(*args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The backend specific docstring is not visible anywhere (neither on keras.io, nor in an IDE). So move to the docstring of
keras/src/ops/core.py
explaining that this is JAX specific.