When I use jax and absl logging in my programs, my logs get filled with tpu_driver-related messages every time a new program is run. For example:

```python
from absl import app, logging
import jax


def main(argv):
    del argv  # Unused.
    devices = jax.devices()
    logging.info(devices)


if __name__ == '__main__':
    app.run(main)
```

will output

which is a bit annoying, considering that it appears every time a program uses jax. Is there a way to disable it? I still want to get info-level logs.
From the code, there's a flag to specify the platforms:

```python
flags.DEFINE_string(
    'jax_platforms',
    os.getenv('JAX_PLATFORMS', '').lower(),
    'Comma-separated list of platform names specifying which platforms jax '
    'should attempt to initialize. The first platform in the list that is '
    'successfully initialized will be used as the default platform. For '
    'example, --jax_platforms=cpu,gpu means that CPU and GPU backends will be '
    'initialized, and the CPU backend will be used unless otherwise specified; '
    '--jax_platforms=cpu means that only the CPU backend will be initialized. '
    'By default, jax will try to initialize all available platforms and will '
    'default to GPU or TPU if available, and fallback to CPU otherwise.')
```

So to remove that info logging, set the `JAX_PLATFORMS` environment variable. I'm leaving this here in case someone else is also bothered by that log :)
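A minimal sketch of the environment-variable route, assuming a CPU-only run (`cpu` is just an example value; list whatever platforms you actually use). Because the flag default above comes from `os.getenv('JAX_PLATFORMS', '')`, the variable has to be set before jax is imported. The `requested_platforms` helper is hypothetical, not a jax API; it only mirrors how that default is lowercased and split on commas.

```python
import os


def requested_platforms() -> list[str]:
    # Hypothetical helper, not part of jax: parse JAX_PLATFORMS the same way the
    # flag default does (getenv + lower), then split the comma-separated list.
    raw = os.getenv("JAX_PLATFORMS", "").lower()
    return [name.strip() for name in raw.split(",") if name.strip()]


# Assumption: a CPU-only run. Set this before `import jax`, because the flag
# default is read from the environment when jax loads its config.
os.environ["JAX_PLATFORMS"] = "cpu"

# An empty list would mean "let jax try every available platform".
print(requested_platforms())  # ['cpu']
```

From a shell, the equivalent is `JAX_PLATFORMS=cpu python my_program.py`. With only the listed backends initialized, jax should not attempt the tpu_driver backend at all, so there is nothing to log about it.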
You can suppress all xla_bridge messages using:

```python
import logging
import jax

# Reject every record logged by jax's xla_bridge module.
logging.getLogger('jax._src.lib.xla_bridge').addFilter(lambda _: False)
```
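The lambda above drops every record from that logger. If you only want to hide the tpu_driver lines and keep other info-level messages, a narrower filter works the same way. This is a sketch, assuming the unwanted messages contain the substring `tpu_driver` and that the logger is still named `jax._src.lib.xla_bridge` (other jax versions may log from a different module path, so check the name in your own output). `DropTpuDriverMessages` and `install_filter` are illustrative names.

```python
import logging

# Logger name taken from the answer above; it may differ between jax versions.
XLA_BRIDGE_LOGGER = "jax._src.lib.xla_bridge"


class DropTpuDriverMessages(logging.Filter):
    """Reject records whose formatted message mentions tpu_driver."""

    def filter(self, record: logging.LogRecord) -> bool:
        return "tpu_driver" not in record.getMessage()


def install_filter(logger_name: str = XLA_BRIDGE_LOGGER) -> logging.Filter:
    # Logger-level filters run before any handler (absl's included), so
    # rejected records are never emitted; other messages pass through.
    flt = DropTpuDriverMessages()
    logging.getLogger(logger_name).addFilter(flt)
    return flt
```

Call `install_filter()` at the start of `main`, before `jax.devices()` is first called. Since absl logging is built on Python's standard `logging` module, filtering at the originating logger should also keep these records out of absl's output.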
try this: