Skip to content

Commit 16491be

Browse files
authored
Fix Optimal Transport for Compiled Contexts (#446)
* remove the is_symbolic_tensor check because this would otherwise skip the whole function for compiled contexts * skip pyabc test * fix sinkhorn and log_sinkhorn message formatting for jax by making the warning message worse
1 parent de8e1cb commit 16491be

File tree

3 files changed

+11
-26
lines changed

3 files changed

+11
-26
lines changed

bayesflow/utils/optimal_transport/log_sinkhorn.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import keras
22

33
from .. import logging
4-
from ..tensor_utils import is_symbolic_tensor
54

65
from .euclidean import euclidean
76

@@ -27,9 +26,6 @@ def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8,
2726

2827
log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
2928

30-
if is_symbolic_tensor(log_plan):
31-
return log_plan
32-
3329
def contains_nans(plan):
3430
return keras.ops.any(keras.ops.isnan(plan))
3531

@@ -57,22 +53,18 @@ def do_nothing():
5753
pass
5854

5955
def log_steps():
60-
msg = "Log-Sinkhorn-Knopp converged after {:d} steps."
56+
msg = "Log-Sinkhorn-Knopp converged after {} steps."
6157

6258
logging.debug(msg, steps)
6359

6460
def warn_convergence():
65-
marginals = keras.ops.logsumexp(log_plan, axis=0)
66-
deviations = keras.ops.abs(marginals)
67-
badness = 100.0 * keras.ops.exp(keras.ops.max(deviations))
68-
69-
msg = "Log-Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)."
61+
msg = "Log-Sinkhorn-Knopp did not converge after {} steps."
7062

71-
logging.warning(msg, max_steps, badness)
63+
logging.warning(msg, max_steps)
7264

7365
def warn_nans():
74-
msg = "Log-Sinkhorn-Knopp produced NaNs."
75-
logging.warning(msg)
66+
msg = "Log-Sinkhorn-Knopp produced NaNs after {} steps."
67+
logging.warning(msg, steps)
7668

7769
keras.ops.cond(contains_nans(log_plan), warn_nans, do_nothing)
7870
keras.ops.cond(is_converged(log_plan), log_steps, warn_convergence)

bayesflow/utils/optimal_transport/sinkhorn.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from bayesflow.types import Tensor
44

55
from .. import logging
6-
from ..tensor_utils import is_symbolic_tensor
76

87
from .euclidean import euclidean
98

@@ -76,9 +75,6 @@ def sinkhorn_plan(
7675
# initialize the transport plan from a gaussian kernel
7776
plan = keras.ops.exp(cost / -(regularization * keras.ops.mean(cost) + 1e-16))
7877

79-
if is_symbolic_tensor(plan):
80-
return plan
81-
8278
def contains_nans(plan):
8379
return keras.ops.any(keras.ops.isnan(plan))
8480

@@ -106,22 +102,18 @@ def do_nothing():
106102
pass
107103

108104
def log_steps():
109-
msg = "Sinkhorn-Knopp converged after {:d} steps."
105+
msg = "Sinkhorn-Knopp converged after {} steps."
110106

111107
logging.info(msg, max_steps)
112108

113109
def warn_convergence():
114-
marginals = keras.ops.sum(plan, axis=0)
115-
deviations = keras.ops.abs(marginals - 1.0)
116-
badness = 100.0 * keras.ops.max(deviations)
117-
118-
msg = "Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)."
110+
msg = "Sinkhorn-Knopp did not converge after {}."
119111

120-
logging.warning(msg, max_steps, badness)
112+
logging.warning(msg, max_steps)
121113

122114
def warn_nans():
123-
msg = "Sinkhorn-Knopp produced NaNs."
124-
logging.warning(msg)
115+
msg = "Sinkhorn-Knopp produced NaNs after {} steps."
116+
logging.warning(msg, steps)
125117

126118
keras.ops.cond(contains_nans(plan), warn_nans, do_nothing)
127119
keras.ops.cond(is_converged(plan), log_steps, warn_convergence)

tests/test_examples/test_examples.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ def test_bayesian_experimental_design(examples_path):
99
run_notebook(examples_path / "Bayesian_Experimental_Design.ipynb")
1010

1111

12+
@pytest.mark.skip(reason="requires setting up pyabc")
1213
@pytest.mark.slow
1314
def test_from_abc_to_bayesflow(examples_path):
1415
run_notebook(examples_path / "From_ABC_to_BayesFlow.ipynb")

0 commit comments

Comments
 (0)