
Commit 4c34065

Merge branch 'dev' into c2st

2 parents: 00ff927 + 22c75d1
14 files changed: 67 additions & 45 deletions

.github/workflows/multiversion-docs.yaml

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ jobs:
       run: |
         cd ./repo/docsrc
         make clean
-        make docs-sequential
+        make production-docs-sequential
 
     - name: Checkout gh-pages-dev
       uses: actions/checkout@v3

.github/workflows/test-docs.yaml

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ jobs:
       run: |
         cd ./docsrc
         make clean
-        make local
+        make local-docs
     - name: Clean up
       run: |
         cd ./docsrc

CONTRIBUTING.md

Lines changed: 3 additions & 3 deletions

@@ -171,17 +171,17 @@ You can re-build the for your local state with:
 
 ```bash
 cd docsrc
-make clean && make local
+make clean && make local-docs
 # in case of issues, try `make clean-all`
 ```
 
 Note that files ignored by git (i.e., listed in `.gitignore`) are not included in the documentation.
 
-We also provide a multi-version documentation. To generate it, run
+We also provide a multi-version documentation, which renders the branches `main` and `stable-legacy`. To generate it, run
 
 ```bash
 cd docsrc
-make clean && make docs
+make clean && make production-docs
 ```
 
 This will create and cache virtual environments for the build at `docsrc/.docs_venvs`.

README.md

Lines changed: 6 additions & 2 deletions

@@ -63,12 +63,16 @@ More tutorials are always welcome! Please consider making a pull request if you
 
 ## Install
 
-BayesFlow is available to install via pip:
+BayesFlow v2 is not yet installable via PyPI, but you can use the following command to install the latest version of the `main` branch:
 
 ```bash
-pip install bayesflow
+pip install git+https://github.com/bayesflow-org/bayesflow.git
 ```
 
+If you encounter problems with this or require more control, please refer to the instructions to install from source below.
+
+Note: `pip install bayesflow` will install the v1 version of BayesFlow.
+
 ### Backend
 
 To use BayesFlow, you will also need to install one of the following machine learning backends.

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 6 additions & 3 deletions

@@ -7,7 +7,7 @@
 import numpy as np
 
 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type
+from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type, weighted_sum
 
 
 from ..inference_network import InferenceNetwork
@@ -285,7 +285,9 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
         out = skip * x + out * f
         return out
 
-    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
+    def compute_metrics(
+        self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
+    ) -> dict[str, Tensor]:
         base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
 
         # The discretization schedule requires the number of passed training steps.
@@ -328,6 +330,7 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
         lam = 1 / (t2 - t1)
 
         # Pseudo-huber loss, see [2], Section 3.3
-        loss = ops.mean(lam * (ops.sqrt(ops.square(teacher_out - student_out) + self.c_huber2) - self.c_huber))
+        loss = lam * (ops.sqrt(ops.square(teacher_out - student_out) + self.c_huber2) - self.c_huber)
+        loss = weighted_sum(loss, sample_weight)
 
         return base_metrics | {"loss": loss}
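For readers unfamiliar with the term being reduced here, the following standalone NumPy sketch reproduces the per-element pseudo-Huber distance that `weighted_sum` now averages. It is illustrative only, not BayesFlow API: the values of `teacher`, `student`, `lam`, and `c_huber` are made up, and it assumes `self.c_huber2` in the source is simply the squared constant.

```python
import numpy as np


def pseudo_huber(teacher_out: np.ndarray, student_out: np.ndarray, c_huber: float = 0.03) -> np.ndarray:
    # Per-element pseudo-Huber distance (cf. [2], Section 3.3 referenced in the diff).
    return np.sqrt((teacher_out - student_out) ** 2 + c_huber**2) - c_huber


# Illustrative values: teacher/student network outputs and the time-step factor lam = 1 / (t2 - t1).
teacher = np.array([0.2, -0.5, 1.0])
student = np.array([0.1, -0.4, 1.3])
lam = 2.0

per_element = lam * pseudo_huber(teacher, student)  # kept element-wise so sample weights can be applied
loss = np.mean(per_element)                         # unweighted reduction, matching weighted_sum(loss, None)
```

Keeping the loss element-wise before the reduction is what allows the new `sample_weight` argument to scale individual samples before the mean is taken.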

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 9 additions & 5 deletions

@@ -2,7 +2,13 @@
 from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.types import Tensor
-from bayesflow.utils import find_permutation, keras_kwargs, serialize_value_or_type, deserialize_value_or_type
+from bayesflow.utils import (
+    find_permutation,
+    keras_kwargs,
+    serialize_value_or_type,
+    deserialize_value_or_type,
+    weighted_sum,
+)
 
 from .actnorm import ActNorm
 from .couplings import DualCoupling
@@ -158,11 +164,9 @@ def _inverse(
     def compute_metrics(
         self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
     ) -> dict[str, Tensor]:
-        if sample_weight is not None:
-            print(sample_weight)
-        base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)
+        base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
 
         z, log_density = self(x, conditions=conditions, inverse=False, density=True)
-        loss = self.aggregate(-log_density, sample_weight)
+        loss = weighted_sum(-log_density, sample_weight)
 
         return base_metrics | {"loss": loss}

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 3 additions & 2 deletions

@@ -13,6 +13,7 @@
     optimal_transport,
     serialize_value_or_type,
     deserialize_value_or_type,
+    weighted_sum,
 )
 from ..inference_network import InferenceNetwork
 
@@ -254,11 +255,11 @@ def compute_metrics(
         x = t * x1 + (1 - t) * x0
         target_velocity = x1 - x0
 
-        base_metrics = super().compute_metrics(x1, conditions, sample_weight, stage)
+        base_metrics = super().compute_metrics(x1, conditions=conditions, stage=stage)
 
         predicted_velocity = self.velocity(x, time=t, conditions=conditions, training=stage == "training")
 
         loss = self.loss_fn(target_velocity, predicted_velocity)
-        loss = self.aggregate(loss, sample_weight)
+        loss = weighted_sum(loss, sample_weight)
 
         return base_metrics | {"loss": loss}

bayesflow/networks/inference_network.py

Lines changed: 1 addition & 10 deletions

@@ -48,9 +48,7 @@ def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tens
         _, log_density = self(samples, conditions=conditions, inverse=False, density=True, **kwargs)
         return log_density
 
-    def compute_metrics(
-        self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
-    ) -> dict[str, Tensor]:
+    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
         if not self.built:
             xz_shape = keras.ops.shape(x)
             conditions_shape = None if conditions is None else keras.ops.shape(conditions)
@@ -66,10 +64,3 @@ def compute_metrics(
             metrics[metric.name] = metric(samples, x)
 
         return metrics
-
-    def aggregate(self, losses: Tensor, weights: Tensor = None):
-        if weights is not None:
-            weighted = losses * weights
-        else:
-            weighted = losses
-        return keras.ops.mean(weighted)

bayesflow/utils/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -71,6 +71,7 @@
     tree_concatenate,
     tree_stack,
     fill_triangular_matrix,
+    weighted_sum,
 )
 from .classification import calibration_curve, confusion_matrix
 from .validators import check_lengths_same

bayesflow/utils/tensor_utils.py

Lines changed: 20 additions & 0 deletions

@@ -140,6 +140,26 @@ def pad(x: Tensor, value: float | Tensor, n: int, axis: int, side: str = "both")
     raise TypeError(f"Invalid side type {type(side)!r}. Must be str.")
 
 
+def weighted_sum(elements: Tensor, weights: Tensor = None) -> Tensor:
+    """
+    Compute the (optionally) weighted mean of the input tensor.
+
+    Parameters
+    ----------
+    elements : Tensor
+        A tensor containing the elements to average.
+    weights : Tensor, optional
+        A tensor of the same shape as `elements` representing weights.
+        If None, the mean is computed without weights.
+
+    Returns
+    -------
+    Tensor
+        A scalar tensor representing the (weighted) mean.
+    """
+    return keras.ops.mean(elements * weights if weights is not None else elements)
+
+
 def searchsorted(sorted_sequence: Tensor, values: Tensor, side: str = "left") -> Tensor:
     """
     Find indices where elements should be inserted to maintain order.
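As a quick orientation, here is a minimal usage sketch of the new helper. It assumes a working keras backend and that this commit's `bayesflow.utils.weighted_sum` export is importable; the array values are made up.

```python
import numpy as np
import keras

from bayesflow.utils import weighted_sum  # exported by this commit

# Illustrative per-element losses and optional per-element weights.
losses = keras.ops.convert_to_tensor(np.array([0.5, 1.0, 2.0], dtype="float32"))
weights = keras.ops.convert_to_tensor(np.array([1.0, 0.0, 2.0], dtype="float32"))

print(weighted_sum(losses))           # plain mean: (0.5 + 1.0 + 2.0) / 3 ≈ 1.167
print(weighted_sum(losses, weights))  # mean of losses * weights: (0.5 + 0.0 + 4.0) / 3 = 1.5
```

Note that, despite its name, `weighted_sum` divides by the number of elements rather than by the sum of the weights, so it is a drop-in replacement for the mean-based reductions it supersedes.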
