Skip to content

Commit a43fd07

Browse files
committed
Merge branch 'dev' of https://github.com/stefanradev93/BayesFlow into dev
2 parents d24f5a3 + acf1c72 commit a43fd07

File tree

6 files changed

+179
-169
lines changed

6 files changed

+179
-169
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
---
2+
name: Bug report
3+
about: Create a bug report to help us improve BayesFlow
4+
title: "[BUG]"
5+
labels: ''
6+
assignees: ''
7+
8+
---
9+
10+
**Describe the bug**
11+
A clear and concise description of what the bug is.
12+
13+
**To Reproduce**
14+
Minimal steps to reproduce the behavior:
15+
1. Import '...'
16+
2. Create network '....'
17+
3. Call '....'
18+
4. See error
19+
20+
**Expected behavior**
21+
A clear and concise description of what you expected to happen.
22+
23+
**Traceback**
24+
If you encounter an error, please provide a complete traceback to help explain your problem.
25+
26+
**Environment**
27+
- OS: [e.g. Ubuntu]
28+
- Python Version: [e.g. 3.11]
29+
- Backend: [e.g. jax, tensorflow, pytorch]
30+
- BayesFlow Version: [e.g. 2.0.2]
31+
32+
**Additional context**
33+
Add any other context about the problem here.
34+
35+
**Minimality**
36+
- [ ] I verify that my example is minimal, does not rely on third-party packages, and is most likely an issue in BayesFlow.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
---
2+
name: Feature request
3+
about: Suggest a new feature to be implemented in BayesFlow
4+
title: "[FEATURE]"
5+
labels: feature
6+
assignees: ''
7+
8+
---
9+
10+
**Is your feature request related to a problem? Please describe.**
11+
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12+
13+
**Describe the solution you'd like**
14+
A clear and concise description of what you want to happen.
15+
16+
**Describe alternatives you've considered**
17+
A clear and concise description of any alternative solutions or features you've considered.
18+
19+
**Additional context**
20+
Add any other context or screenshots about the feature request here.

bayesflow/utils/optimal_transport/log_sinkhorn.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import keras
22

33
from .. import logging
4-
from ..tensor_utils import is_symbolic_tensor
54

65
from .euclidean import euclidean
76

@@ -27,9 +26,6 @@ def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8,
2726

2827
log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
2928

30-
if is_symbolic_tensor(log_plan):
31-
return log_plan
32-
3329
def contains_nans(plan):
3430
return keras.ops.any(keras.ops.isnan(plan))
3531

@@ -57,22 +53,18 @@ def do_nothing():
5753
pass
5854

5955
def log_steps():
60-
msg = "Log-Sinkhorn-Knopp converged after {:d} steps."
56+
msg = "Log-Sinkhorn-Knopp converged after {} steps."
6157

6258
logging.debug(msg, steps)
6359

6460
def warn_convergence():
65-
marginals = keras.ops.logsumexp(log_plan, axis=0)
66-
deviations = keras.ops.abs(marginals)
67-
badness = 100.0 * keras.ops.exp(keras.ops.max(deviations))
68-
69-
msg = "Log-Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)."
61+
msg = "Log-Sinkhorn-Knopp did not converge after {} steps."
7062

71-
logging.warning(msg, max_steps, badness)
63+
logging.warning(msg, max_steps)
7264

7365
def warn_nans():
74-
msg = "Log-Sinkhorn-Knopp produced NaNs."
75-
logging.warning(msg)
66+
msg = "Log-Sinkhorn-Knopp produced NaNs after {} steps."
67+
logging.warning(msg, steps)
7668

7769
keras.ops.cond(contains_nans(log_plan), warn_nans, do_nothing)
7870
keras.ops.cond(is_converged(log_plan), log_steps, warn_convergence)

bayesflow/utils/optimal_transport/sinkhorn.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from bayesflow.types import Tensor
44

55
from .. import logging
6-
from ..tensor_utils import is_symbolic_tensor
76

87
from .euclidean import euclidean
98

@@ -76,9 +75,6 @@ def sinkhorn_plan(
7675
# initialize the transport plan from a gaussian kernel
7776
plan = keras.ops.exp(cost / -(regularization * keras.ops.mean(cost) + 1e-16))
7877

79-
if is_symbolic_tensor(plan):
80-
return plan
81-
8278
def contains_nans(plan):
8379
return keras.ops.any(keras.ops.isnan(plan))
8480

@@ -106,22 +102,18 @@ def do_nothing():
106102
pass
107103

108104
def log_steps():
109-
msg = "Sinkhorn-Knopp converged after {:d} steps."
105+
msg = "Sinkhorn-Knopp converged after {} steps."
110106

111107
logging.info(msg, max_steps)
112108

113109
def warn_convergence():
114-
marginals = keras.ops.sum(plan, axis=0)
115-
deviations = keras.ops.abs(marginals - 1.0)
116-
badness = 100.0 * keras.ops.max(deviations)
117-
118-
msg = "Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)."
110+
msg = "Sinkhorn-Knopp did not converge after {}."
119111

120-
logging.warning(msg, max_steps, badness)
112+
logging.warning(msg, max_steps)
121113

122114
def warn_nans():
123-
msg = "Sinkhorn-Knopp produced NaNs."
124-
logging.warning(msg)
115+
msg = "Sinkhorn-Knopp produced NaNs after {} steps."
116+
logging.warning(msg, steps)
125117

126118
keras.ops.cond(contains_nans(plan), warn_nans, do_nothing)
127119
keras.ops.cond(is_converged(plan), log_steps, warn_convergence)

tests/test_examples/test_examples.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ def test_bayesian_experimental_design(examples_path):
99
run_notebook(examples_path / "Bayesian_Experimental_Design.ipynb")
1010

1111

12+
@pytest.mark.skip(reason="requires setting up pyabc")
1213
@pytest.mark.slow
1314
def test_from_abc_to_bayesflow(examples_path):
1415
run_notebook(examples_path / "From_ABC_to_BayesFlow.ipynb")

0 commit comments

Comments
 (0)