Skip to content

Commit 0be86da

Browse files
committed
Revert unrelated code, remove deprecated code from a test
1 parent 6fbdb17 commit 0be86da

File tree

4 files changed

+9
-17
lines changed

4 files changed

+9
-17
lines changed

tests/link/jax/test_scalar.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,10 @@
3838

3939

4040
try:
41-
import tensorflow_probability # noqa: F401
42-
from jax.interpreters.xla import (
43-
pytype_aval_mappings, # This is what's missing in new JAX # noqa: F401
44-
)
41+
pass
4542

4643
TFP_INSTALLED = True
47-
except (ModuleNotFoundError, AttributeError, ImportError):
44+
except ModuleNotFoundError:
4845
TFP_INSTALLED = False
4946

5047

@@ -163,7 +160,6 @@ def test_tfp_ops(op, test_values):
163160
compare_jax_and_py(inputs, [output], test_values)
164161

165162

166-
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
167163
def test_betaincinv():
168164
a = vector("a", dtype="float64")
169165
b = vector("b", dtype="float64")
@@ -181,7 +177,6 @@ def test_betaincinv():
181177
)
182178

183179

184-
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
185180
def test_gammaincinv():
186181
k = vector("k", dtype="float64")
187182
x = vector("x", dtype="float64")
@@ -190,7 +185,6 @@ def test_gammaincinv():
190185
compare_jax_and_py([k, x], [out], [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
191186

192187

193-
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
194188
def test_gammainccinv():
195189
k = vector("k", dtype="float64")
196190
x = vector("x", dtype="float64")

tests/link/jax/test_tensor_basic.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,6 @@ def test_tri_nonconcrete():
226226

227227
out = ptb.tri(m, n, k)
228228

229-
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but
230-
# the error handler raises an Attribute error first, so that's what this test needs to pass
231229
with pytest.raises(
232230
NotImplementedError,
233231
match=re.escape(

tests/tensor/rewriting/test_elemwise.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,9 +1642,9 @@ def test_InplaceElemwiseOptimizer_bug():
16421642
# with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
16431643
rewrite_graph(fgraph, include=("inplace",))
16441644

1645-
with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=1):
1646-
with pytest.warns(
1647-
FutureWarning,
1648-
match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
1649-
):
1650-
rewrite_graph(fgraph, include=("inplace",))
1645+
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1
1646+
with pytest.warns(
1647+
FutureWarning,
1648+
match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
1649+
):
1650+
rewrite_graph(fgraph, include=("inplace",))

tests/tensor/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ def test_masked_array_not_implemented(
704704

705705

706706
def check_alloc_runtime_broadcast(mode):
707-
"""Check we emit a clear error when runtime broadcasting would occur according to Numpy rules."""
707+
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
708708
floatX = config.floatX
709709
x_v = vector("x", shape=(None,))
710710

0 commit comments

Comments
 (0)