Skip to content

Commit 703b050

Browse files
committed
Merge branch 'dev' into feat-test-non-invertible
2 parents 661f89d + 57c9ad8 commit 703b050

File tree

18 files changed

+728
-1404
lines changed

18 files changed

+728
-1404
lines changed

.github/workflows/publish.yaml

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,53 @@
1+
name: Publish Python 🐍 distribution 📦 to PyPI
12

2-
name: Publish to PyPI.org
33
on:
44
release:
55
types: [published]
6+
67
jobs:
7-
pypi:
8+
build:
9+
name: Build distribution 📦
10+
runs-on: ubuntu-latest
11+
12+
steps:
13+
- uses: actions/checkout@v4
14+
with:
15+
persist-credentials: false
16+
- name: Set up Python
17+
uses: actions/setup-python@v5
18+
with:
19+
python-version: "3.x"
20+
- name: Install pypa/build
21+
run: >-
22+
python3 -m
23+
pip install
24+
build
25+
--user
26+
- name: Build a binary wheel and a source tarball
27+
run: python3 -m build
28+
- name: Store the distribution packages
29+
uses: actions/upload-artifact@v4
30+
with:
31+
name: python-package-distributions
32+
path: dist/
33+
34+
publish-to-pypi:
35+
name: >-
36+
Publish Python 🐍 distribution 📦 to PyPI
37+
needs:
38+
- build
839
runs-on: ubuntu-latest
40+
environment:
41+
name: pypi
42+
url: https://pypi.org/p/bayesflow # PyPI project page of the published package
43+
permissions:
44+
id-token: write # IMPORTANT: mandatory for trusted publishing
45+
946
steps:
10-
- name: Checkout
11-
uses: actions/checkout@v4
12-
with:
13-
fetch-depth: 0
14-
- run: python3 -m pip install -U build && python3 -m build
15-
- name: Publish package
16-
uses: pypa/gh-action-pypi-publish@release/v1
17-
with:
18-
password: ${{ secrets.PYPI_API_TOKEN }}
47+
- name: Download all the dists
48+
uses: actions/download-artifact@v4
49+
with:
50+
name: python-package-distributions
51+
path: dist/
52+
- name: Publish distribution 📦 to PyPI
53+
uses: pypa/gh-action-pypi-publish@release/v1

README.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,19 @@ More tutorials are always welcome! Please consider making a pull request if you
6464

6565
## Install
6666

67-
BayesFlow v2 is not yet installable via PyPI, but you can use the following command to install the latest version of the `main` branch:
67+
You can install the latest stable version from PyPI using:
6868

6969
```bash
70-
pip install git+https://github.com/bayesflow-org/bayesflow.git
70+
pip install bayesflow
7171
```
7272

73-
If you encounter problems with this or require more control, please refer to the instructions to install from source below.
73+
If you want the latest features, you can install from source:
7474

75-
Note: `pip install bayesflow` will install the v1 version of BayesFlow.
75+
```bash
76+
pip install git+https://github.com/bayesflow-org/bayesflow.git@dev
77+
```
78+
79+
If you encounter problems with this or require more control, please refer to the instructions to install from source below.
7680

7781
### Backend
7882

bayesflow/diagnostics/plots/loss.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def loss(
1919
figsize: Sequence[float] = None,
2020
train_color: str = "#132a70",
2121
val_color: str = "black",
22+
val_marker: str = "o",
23+
val_marker_size: float = 5,
2224
lw_train: float = 2.0,
2325
lw_val: float = 2.0,
2426
grid_alpha: float = 0.2,
@@ -49,10 +51,14 @@ def loss(
4951
The color for the train loss trajectory
5052
val_color : str, optional, default: "black"
5153
The color for the optional validation loss trajectory
54+
val_marker: str
55+
Marker style for the validation loss curve. Default is "o".
56+
val_marker_size: float
57+
Marker size for the validation loss curve. Default is 5.
5258
lw_train : float, optional, default: 2.0
53-
The linewidth for the training loss curve
59+
The line width for the training loss curve
5460
lw_val : float, optional, default: 2.0
55-
The linewidth for the validation loss curve
61+
The line width for the validation loss curve
5662
grid_alpha : float, optional, default: 0.2
5763
The transparency of the background grid
5864
legend_fontsize : int, optional, default: 14
@@ -130,6 +136,9 @@ def loss(
130136
color=val_color,
131137
lw=lw_val,
132138
alpha=alpha_unsmoothed,
139+
linestyle="--",
140+
marker=val_marker,
141+
markersize=val_marker_size,
133142
label="Validation",
134143
)
135144

@@ -140,6 +149,7 @@ def loss(
140149
val_step_index,
141150
smoothed_val_loss,
142151
color=val_color,
152+
linestyle="--",
143153
lw=lw_val,
144154
alpha=0.8,
145155
label="Validation (Moving Average)",

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
The base probability distribution from which samples are drawn, such as "normal".
8383
Default is "normal".
8484
use_optimal_transport : bool, optional
85-
Whether to apply optimal transport for improved training stability. Default is False.
85+
Whether to apply optimal transport for improved training stability. Default is True.
8686
loss_fn : str, optional
8787
The loss function used for training, such as "mse". Default is "mse".
8888
integrate_kwargs : dict[str, any], optional
@@ -256,9 +256,20 @@ def compute_metrics(
256256
x0 = self.base_distribution.sample(keras.ops.shape(x1)[:-1])
257257

258258
if self.use_optimal_transport:
259-
x1, x0, conditions = optimal_transport(
260-
x1, x0, conditions, seed=self.seed_generator, **self.optimal_transport_kwargs
259+
# we must choose between resampling x0 or x1
260+
# since the data is possibly noisy and may contain outliers, it is better
261+
# to possibly drop some samples from x1 than from x0
262+
# in the marginal over multiple batches, this is not a problem
263+
x0, x1, assignments = optimal_transport(
264+
x0,
265+
x1,
266+
seed=self.seed_generator,
267+
**self.optimal_transport_kwargs,
268+
return_assignments=True,
261269
)
270+
if conditions is not None:
271+
# conditions must be resampled along with x1
272+
conditions = keras.ops.take(conditions, assignments, axis=0)
262273

263274
t = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator)
264275
t = expand_right_as(t, x0)
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
from .optimal_transport import optimal_transport
2-
from .sinkhorn import sinkhorn, sinkhorn_indices, sinkhorn_plan
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import keras
2+
3+
4+
def euclidean(x1, x2):
    """Return the matrix of pairwise Euclidean distances between the entries of ``x1`` and ``x2``.

    All axes after the first are flattened into a single feature axis before the L2 norm is taken,
    so entry ``[i, j]`` of the result is the distance between ``x1[i]`` and ``x2[j]``.
    """
    # TODO: rename and move this function
    # broadcast so that every (i, j) pair gets its own difference
    pairwise_diff = x1[:, None] - x2[None, :]

    # keep the two pairing axes and collapse everything else into one trailing axis
    pair_axes = list(keras.ops.shape(pairwise_diff))[:2]
    flattened = keras.ops.reshape(pairwise_diff, [*pair_axes, -1])

    return keras.ops.norm(flattened, ord=2, axis=-1)

bayesflow/utils/optimal_transport/hungarian.py

Lines changed: 0 additions & 5 deletions
This file was deleted.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import keras
2+
3+
from .. import logging
4+
from ..tensor_utils import is_symbolic_tensor
5+
6+
from .euclidean import euclidean
7+
8+
9+
def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
    """
    Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`.
    Significantly slower than the unstabilized version, so use only when you need numerical stability.

    :param x1: Samples from the first distribution; one assignment is drawn per row of the log plan.

    :param x2: Samples from the second distribution; the returned indices select entries of x2.

    :param seed: Seed (or seed generator) forwarded to ``keras.random.categorical``.

    :param kwargs: Additional keyword arguments forwarded to :py:func:`log_sinkhorn_plan`.

    :return: Integer tensor holding one sampled column index of the log plan per row.
    """
    log_plan = log_sinkhorn_plan(x1, x2, **kwargs)

    # keras.random.categorical expects logits (unnormalized log-probabilities), so the log plan is
    # passed as-is; exponentiating it first would softmax the probabilities themselves and sample
    # almost uniformly instead of following the transport plan
    assignments = keras.random.categorical(log_plan, num_samples=1, seed=seed)
    assignments = keras.ops.squeeze(assignments, axis=1)

    return assignments
19+
20+
21+
def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, max_steps=None):
    """
    Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn_plan`.
    Significantly slower than the unstabilized version, so use only when you need numerical stability.

    Starting from the negated, rescaled pairwise Euclidean cost, the plan is alternately normalized in
    log space along both axes until all log-marginals are close to zero, NaNs appear, or ``max_steps``
    iterations have been performed.

    :param x1: First batch of samples; passed as the first argument to ``euclidean``.

    :param x2: Second batch of samples; passed as the second argument to ``euclidean``.

    :param regularization: Factor applied to the mean cost that rescales the cost matrix.

    :param rtol: Relative tolerance of the convergence check.

    :param atol: Absolute tolerance of the convergence check.

    :param max_steps: Forwarded to ``keras.ops.while_loop`` as ``maximum_iterations``.

    :return: The log transport plan. For symbolic inputs the initial, un-iterated log plan is returned.
    """
    cost_matrix = euclidean(x1, x2)

    # rescale by the mean cost; the small constant presumably guards against division by zero
    scale = regularization * keras.ops.mean(cost_matrix) + 1e-16
    log_plan = cost_matrix / -scale

    if is_symbolic_tensor(log_plan):
        # symbolic tensors are returned without running the iteration
        return log_plan

    def has_nans(candidate):
        return keras.ops.any(keras.ops.isnan(candidate))

    def log_marginals_vanish(candidate, axis):
        log_marginals = keras.ops.logsumexp(candidate, axis=axis)
        return keras.ops.all(keras.ops.isclose(log_marginals, 0.0, rtol=rtol, atol=atol))

    def is_doubly_stochastic(candidate):
        # for convergence, the plan should be doubly stochastic
        columns_ok = log_marginals_vanish(candidate, 0)
        rows_ok = log_marginals_vanish(candidate, 1)
        return columns_ok & rows_ok

    def keep_iterating(_, candidate):
        # break the while loop if the plan contains nans or is converged
        return ~(has_nans(candidate) | is_doubly_stochastic(candidate))

    def sinkhorn_step(step_count, candidate):
        # Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
        candidate = keras.ops.log_softmax(candidate, axis=0)
        candidate = keras.ops.log_softmax(candidate, axis=1)

        return step_count + 1, candidate

    steps, log_plan = keras.ops.while_loop(
        keep_iterating, sinkhorn_step, (0, log_plan), maximum_iterations=max_steps
    )

    def do_nothing():
        pass

    def log_steps():
        logging.info("Log-Sinkhorn-Knopp converged after {:d} steps.", steps)

    def warn_convergence():
        # report the largest absolute log-marginal along axis 0 as a percentage
        log_marginals = keras.ops.logsumexp(log_plan, axis=0)
        worst = keras.ops.max(keras.ops.abs(log_marginals))
        badness = 100.0 * keras.ops.exp(worst)

        logging.warning(
            "Log-Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%).", max_steps, badness
        )

    def warn_nans():
        logging.warning("Log-Sinkhorn-Knopp produced NaNs.")

    keras.ops.cond(has_nans(log_plan), warn_nans, do_nothing)
    keras.ops.cond(is_doubly_stochastic(log_plan), log_steps, warn_convergence)

    return log_plan
Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1-
from bayesflow.types import Tensor
1+
import keras
22

3-
from .hungarian import hungarian
4-
from .random import random
3+
from .log_sinkhorn import log_sinkhorn
54
from .sinkhorn import sinkhorn
65

6+
methods = {
7+
"sinkhorn": sinkhorn,
8+
"sinkhorn_knopp": sinkhorn,
9+
"log_sinkhorn": log_sinkhorn,
10+
"log_sinkhorn_knopp": log_sinkhorn,
11+
}
712

8-
def optimal_transport(
9-
x1: Tensor, x2: Tensor, *aux: Tensor, method: str = "sinkhorn_knopp", **kwargs
10-
) -> (Tensor, Tensor):
13+
14+
def optimal_transport(x1, x2, method="log_sinkhorn", return_assignments=False, **kwargs):
1115
"""Matches elements from x2 onto x1, such that the transport cost between them is minimized, according to the method
1216
and cost matrix used.
1317
@@ -23,28 +27,21 @@ def optimal_transport(
2327
:param x2: Tensor of shape (m, ...)
2428
Samples from the second distribution.
2529
26-
:param aux: Tensors of shape (n, ...)
27-
Auxiliary tensors to be permuted along with x1.
28-
Note that x2 is never permuted for all currently available methods.
29-
3030
:param method: Method used to compute the transport cost.
31-
Default: 'sinkhorn_knopp'
31+
Default: 'log_sinkhorn'
3232
33-
:param kwargs: Additional keyword arguments passed to the optimization method.
33+
:param return_assignments: Whether to return the assignment indices.
34+
Default: False
35+
36+
:param kwargs: Additional keyword arguments that are passed to the optimization method.
3437
3538
:return: Tensors of shapes (n, ...) and (m, ...)
3639
x1 unchanged and x2 reindexed by the computed assignments. If return_assignments is True, the assignment indices are returned as a third element.
3740
"""
38-
methods = {
39-
"hungarian": hungarian,
40-
"sinkhorn": sinkhorn,
41-
"sinkhorn_knopp": sinkhorn,
42-
"random": random,
43-
}
44-
45-
method = method.lower()
41+
assignments = methods[method.lower()](x1, x2, **kwargs)
42+
x2 = keras.ops.take(x2, assignments, axis=0)
4643

47-
if method not in methods:
48-
raise ValueError(f"Unsupported method name: '{method}'.")
44+
if return_assignments:
45+
return x1, x2, assignments
4946

50-
return methods[method](x1, x2, *aux, **kwargs)
47+
return x1, x2

bayesflow/utils/optimal_transport/random.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

0 commit comments

Comments
 (0)