We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents cb6881d + acae2f0 commit 210bd30Copy full SHA for 210bd30
jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -44,20 +44,8 @@
44
45
import numpy as np
46
import tensorflow as tf
47
-# pylint: disable=g-direct-tensorflow-import
48
-from tensorflow.compiler.tf2xla.python import xla as tfxla
49
-# pylint: enable=g-direct-tensorflow-import
50
51
config.parse_flags_with_absl()
52
-_exit_stack = contextlib.ExitStack()
53
-
54
-# TODO(necula): Remove once tensorflow is 2.10.0 everywhere.
55
-def setUpModule():
56
- if not hasattr(tfxla, "optimization_barrier"):
57
- _exit_stack.enter_context(jtu.global_config_context(jax_remat_opt_barrier=False))
58
59
-def tearDownModule():
60
- _exit_stack.close()
61
62
63
class Jax2TfTest(tf_test_util.JaxToTfTestCase):
0 commit comments