Skip to content

Commit acae2f0

Browse files
committed
Remove code in jax2tf for compatibility with TF 2.10 or earlier.
1 parent 263d4d1 commit acae2f0

File tree

1 file changed

+0
-12
lines changed

1 file changed

+0
-12
lines changed

jax/experimental/jax2tf/tests/jax2tf_test.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,8 @@
4444

4545
import numpy as np
4646
import tensorflow as tf
47-
# pylint: disable=g-direct-tensorflow-import
48-
from tensorflow.compiler.tf2xla.python import xla as tfxla
49-
# pylint: enable=g-direct-tensorflow-import
5047

5148
config.parse_flags_with_absl()
52-
_exit_stack = contextlib.ExitStack()
53-
54-
# TODO(necula): Remove once tensorflow is 2.10.0 everywhere.
55-
def setUpModule():
56-
if not hasattr(tfxla, "optimization_barrier"):
57-
_exit_stack.enter_context(jtu.global_config_context(jax_remat_opt_barrier=False))
58-
59-
def tearDownModule():
60-
_exit_stack.close()
6149

6250

6351
class Jax2TfTest(tf_test_util.JaxToTfTestCase):

0 commit comments

Comments
 (0)