Skip to content

Commit 4877f6b

Browse files
committed
use os.environ to find out whether jax is used
1 parent 1e88ba0 commit 4877f6b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/test_approximators/test_add_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ def call(self, x, training=False, **kwargs):
2323

2424

2525
def test_layer_loss_reported(approximator_using_add_loss, train_dataset, validation_dataset):
26-
from bayesflow.approximators.backend_approximators.jax_approximator import JAXApproximator
26+
import os
2727

28-
if isinstance(approximator_using_add_loss, JAXApproximator):
28+
if os.environ["KERAS_BACKEND"] == "jax":
2929
pytest.skip(reason="With JAX backend, the compute_metrics method currently fails to consider self.losses.")
3030

3131
approximator = approximator_using_add_loss

0 commit comments

Comments
 (0)