Merged
Changes from all commits
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/as_set.py
@@ -1,5 +1,5 @@
-import numpy as np
 from keras.saving import register_keras_serializable as serializable
+import numpy as np
 
 from .elementwise_transform import ElementwiseTransform
 
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/sqrt.py
@@ -1,5 +1,5 @@
-import numpy as np
 from keras.saving import register_keras_serializable as serializable
+import numpy as np
 
 from .elementwise_transform import ElementwiseTransform
 
7 changes: 4 additions & 3 deletions bayesflow/links/ordered.py
@@ -12,6 +12,7 @@ def __init__(self, axis: int, anchor_index: int, **kwargs):
         super().__init__(**keras_kwargs(kwargs))
         self.axis = axis
         self.anchor_index = anchor_index
+        self.group_indices = None
 
         self.config = {"axis": axis, "anchor_index": anchor_index, **kwargs}
 
@@ -22,9 +23,9 @@ def get_config(self):
     def build(self, input_shape):
         super().build(input_shape)
 
-        assert self.anchor_index % input_shape[self.axis] != 0 and self.anchor_index != -1, (
-            "anchor should not be first or last index."
-        )
+        if self.anchor_index % input_shape[self.axis] == 0 or self.anchor_index == -1:
+            raise RuntimeError("Anchor should not be first or last index.")
+
         self.group_indices = dict(
             below=list(range(0, self.anchor_index)),
             above=list(range(self.anchor_index + 1, input_shape[self.axis])),
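Note on the change above: unlike `assert`, which is stripped when Python runs with optimizations enabled, the explicit `raise` always fires. A minimal sketch of the new behavior, assuming `Ordered` is exported from `bayesflow.links` (the constructor signature is the one shown in the hunk):

```python
from bayesflow.links import Ordered

# Anchor index 0 is the first position along the ordered axis,
# which the new check rejects with a RuntimeError instead of an assert.
link = Ordered(axis=-1, anchor_index=0)

try:
    link.build((None, 5))  # 5 ordered values along the last axis
except RuntimeError as err:
    print(err)  # Anchor should not be first or last index.
```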
21 changes: 11 additions & 10 deletions bayesflow/links/ordered_quantiles.py
@@ -44,16 +44,17 @@ def build(self, input_shape):
         else:
             # choose quantile level closest to median as anchor index
             self.anchor_index = keras.ops.argmin(keras.ops.abs(keras.ops.convert_to_tensor(self.q) - 0.5))
-            msg = (
-                "Length of `q` does not coincide with input shape: "
-                f"len(q)={len(self.q)}, position {self.axis} of shape={input_shape}"
-            )
-            assert num_quantile_levels == len(self.q), msg
 
-        msg = (
-            "The link function `OrderedQuantiles` expects at least 3 quantile levels,"
-            f" but only {num_quantile_levels} were given."
-        )
-        assert self.anchor_index not in (0, -1, num_quantile_levels - 1), msg
+        if len(self.q) != num_quantile_levels:
+            raise RuntimeError(
+                f"Length of `q` does not coincide with input shape: len(q)={len(self.q)}, "
+                f"position {self.axis} of shape={input_shape}"
+            )
+
+        if self.anchor_index in [0, -1, num_quantile_levels - 1]:
+            raise RuntimeError(
+                f"The link function `OrderedQuantiles` expects at least 3 quantile levels, "
+                f"but only {num_quantile_levels} were given."
+            )
 
         super().build(input_shape)
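For context on why these hunks (and the ones below) trade `assert` for raised exceptions: assertions are compiled away under `python -O`, so assert-based input validation can silently disappear in optimized runs. A small standalone demonstration in plain Python, independent of bayesflow:

```python
import subprocess
import sys

# Under -O, the failing assert is removed and the process exits cleanly,
# while the explicit raise still aborts with a nonzero exit code.
assert_check = "assert 2 + 2 == 5, 'check failed'"
raise_check = "if 2 + 2 != 5:\n    raise RuntimeError('check failed')"

for label, code in [("assert", assert_check), ("raise", raise_check)]:
    result = subprocess.run([sys.executable, "-O", "-c", code])
    print(label, "exit code:", result.returncode)  # assert -> 0, raise -> 1
```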
2 changes: 2 additions & 0 deletions bayesflow/links/positive_semi_definite.py
@@ -1,8 +1,10 @@
 import keras
+from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.utils import keras_kwargs
 
 
+@serializable(package="bayesflow.links")
 class PositiveSemiDefinite(keras.Layer):
     """Activation function to link from any square matrix to a positive semidefinite matrix."""
 
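The `@serializable(...)` registration added here (and to the scores below) is what lets Keras resolve these classes by name when a saved config is loaded back. A rough round-trip sketch, assuming `PositiveSemiDefinite` is exported from `bayesflow.links` and can be constructed without arguments:

```python
import keras
from bayesflow.links import PositiveSemiDefinite

# The decorator registers the class under "<package>><class name>",
# so serialized configs can be mapped back to the Python class.
print(keras.saving.get_registered_name(PositiveSemiDefinite))  # bayesflow.links>PositiveSemiDefinite

link = PositiveSemiDefinite()
config = keras.saving.serialize_keras_object(link)
restored = keras.saving.deserialize_keras_object(config)
print(type(restored) is PositiveSemiDefinite)  # True
```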
5 changes: 4 additions & 1 deletion bayesflow/networks/consistency_models/consistency_model.py
@@ -177,7 +177,10 @@ def build(self, xz_shape, conditions_shape=None):
         # First, we calculate all unique numbers of discretization steps n
         # in a loop, as self.total_steps might be large
         self.max_n = int(self._schedule_discretization(self.total_steps))
-        assert self.max_n == self.s1 + 1
+
+        if self.max_n != self.s1 + 1:
+            raise ValueError("The maximum number of discretization steps must be equal to s1 + 1.")
+
         unique_n = set()
         for step in range(int(self.total_steps)):
             unique_n.add(int(self._schedule_discretization(step)))
5 changes: 4 additions & 1 deletion bayesflow/networks/embeddings/fourier_embedding.py
@@ -39,7 +39,10 @@ def __init__(
"""

         super().__init__(**kwargs)
-        assert embed_dim % 2 == 0, f"Embedding dimension must be even, but is {embed_dim}."
+
+        if embed_dim % 2 != 0:
+            raise ValueError(f"Embedding dimension must be even, but is {embed_dim}.")
+
         self.w = self.add_weight(initializer=initializer, shape=(embed_dim // 2,), trainable=trainable)
         self.scale = scale
         self.embed_dim = embed_dim
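The even-dimension requirement exists because the embedding concatenates a sine half and a cosine half of equal width. A NumPy sketch of that general idea (not the exact `FourierEmbedding` implementation):

```python
import numpy as np

def fourier_embed(t, w, scale=1.0):
    # (..., 1) * (embed_dim // 2,) -> (..., embed_dim // 2) angles;
    # sin and cos halves are concatenated into embed_dim features.
    angles = 2.0 * np.pi * scale * t[..., None] * w
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)

w = np.random.normal(size=(4,))   # embed_dim // 2 = 4 random frequencies
t = np.array([0.1, 0.5, 0.9])     # e.g. time steps to embed
print(fourier_embed(t, w).shape)  # (3, 8) -> embed_dim = 8, necessarily even
```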
3 changes: 3 additions & 0 deletions bayesflow/scores/mean_score.py
@@ -1,6 +1,9 @@
+from keras.saving import register_keras_serializable as serializable
+
 from .normed_difference_score import NormedDifferenceScore
 
 
+@serializable(package="bayesflow.scores")
 class MeanScore(NormedDifferenceScore):
     r""":math:`S(\hat \theta, \theta) = | \hat \theta - \theta |^2`
3 changes: 3 additions & 0 deletions bayesflow/scores/median_score.py
@@ -1,6 +1,9 @@
+from keras.saving import register_keras_serializable as serializable
+
 from .normed_difference_score import NormedDifferenceScore
 
 
+@serializable(package="bayesflow.scores")
 class MedianScore(NormedDifferenceScore):
     r""":math:`S(\hat \theta, \theta) = | \hat \theta - \theta |`
13 changes: 10 additions & 3 deletions bayesflow/scores/multivariate_normal_score.py
@@ -1,6 +1,7 @@
 import math
 
 import keras
+from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.types import Shape, Tensor
 from bayesflow.links import PositiveSemiDefinite
@@ -9,6 +10,7 @@
 from .parametric_distribution_score import ParametricDistributionScore
 
 
+@serializable(package="bayesflow.scores")
 class MultivariateNormalScore(ParametricDistributionScore):
     r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = \log( \mathcal N (\theta; \mu, \Sigma))`
@@ -96,9 +98,14 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor
         A tensor of shape (batch_size, num_samples, D) containing the generated samples.
         """
         batch_size, num_samples = batch_shape
-        dim = mean.shape[-1]
-        assert mean.shape == (batch_size, dim), "mean must have shape (batch_size, D)"
-        assert covariance.shape == (batch_size, dim, dim), "covariance must have shape (batch_size, D, D)"
+        dim = keras.ops.shape(mean)[-1]
+        if keras.ops.shape(mean) != (batch_size, dim):
+            raise ValueError(f"mean must have shape (batch_size, {dim}), but got {keras.ops.shape(mean)}")
+
+        if keras.ops.shape(covariance) != (batch_size, dim, dim):
+            raise ValueError(
+                f"covariance must have shape (batch_size, {dim}, {dim}), but got {keras.ops.shape(covariance)}"
+            )
 
         # Use Cholesky decomposition to generate samples
         cholesky_factor = keras.ops.cholesky(covariance)
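The sampling path relies on the standard Cholesky construction: if covariance = L @ L^T and z ~ N(0, I), then mean + L @ z ~ N(mean, covariance). A NumPy sketch with the same (batch_size, num_samples, D) layout as the docstring above:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, num_samples, dim = 2, 1000, 3

mean = rng.normal(size=(batch_size, dim))
a = rng.normal(size=(batch_size, dim, dim))
covariance = a @ np.swapaxes(a, -1, -2) + dim * np.eye(dim)  # SPD by construction

chol = np.linalg.cholesky(covariance)                        # (batch_size, dim, dim)
z = rng.standard_normal((batch_size, num_samples, dim))      # z ~ N(0, I)
samples = mean[:, None, :] + z @ np.swapaxes(chol, -1, -2)   # mean + L @ z per sample
print(samples.shape)  # (2, 1000, 3) = (batch_size, num_samples, D)
```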
2 changes: 2 additions & 0 deletions bayesflow/scores/normed_difference_score.py
@@ -1,10 +1,12 @@
 import keras
+from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.types import Shape, Tensor
 
 from .scoring_rule import ScoringRule
 
 
+@serializable(package="bayesflow.scores")
 class NormedDifferenceScore(ScoringRule):
     r""":math:`S(\hat \theta, \theta; k) = | \hat \theta - \theta |^k`
3 changes: 3 additions & 0 deletions bayesflow/scores/parametric_distribution_score.py
@@ -1,8 +1,11 @@
+from keras.saving import register_keras_serializable as serializable
+
 from bayesflow.types import Tensor
 
 from .scoring_rule import ScoringRule
 
 
+@serializable(package="bayesflow.scores")
 class ParametricDistributionScore(ScoringRule):
     r""":math:`S(\hat p_\phi, \theta; k) = \log(\hat p_\phi(\theta))`
2 changes: 2 additions & 0 deletions bayesflow/scores/quantile_score.py
@@ -1,6 +1,7 @@
 from typing import Sequence
 
 import keras
+from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.types import Shape, Tensor
 from bayesflow.utils import logging
@@ -9,6 +10,7 @@
 from .scoring_rule import ScoringRule
 
 
+@serializable(package="bayesflow.scores")
 class QuantileScore(ScoringRule):
     r""":math:`S(\hat \theta_i, \theta; \tau_i)
     = (\hat \theta_i - \theta)(\mathbf{1}_{\hat \theta - \theta > 0} - \tau_i)`
2 changes: 2 additions & 0 deletions bayesflow/scores/scoring_rule.py
@@ -1,11 +1,13 @@
 import math
 
 import keras
+from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.types import Shape, Tensor
 from bayesflow.utils import find_network, serialize_value_or_type, deserialize_value_or_type
 
 
+@serializable(package="bayesflow.scores")
 class ScoringRule:
     """Base class for scoring rules.
4 changes: 3 additions & 1 deletion bayesflow/utils/_docs/_populate_all.py
@@ -3,7 +3,9 @@

 def _add_imports_to_all(include_modules: bool | list[str] = False, exclude: list[str] | None = None):
     """Add all global variables to __all__"""
-    assert type(include_modules) in [bool, list]
+    if not isinstance(include_modules, (bool, list)):
+        raise ValueError("include_modules must be a boolean or a list of strings")
+
     exclude = exclude or []
     calling_module = inspect.stack()[1]
     local_stack = calling_module[0]
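Besides dropping the `assert`, the `isinstance` check also accepts subclasses, which the exact `type(...) in [...]` test rejected. A tiny illustration with a hypothetical `list` subclass:

```python
class ModuleList(list):  # hypothetical subclass, for illustration only
    pass

include_modules = ModuleList(["links", "scores"])

print(type(include_modules) in [bool, list])      # False: exact-type test rejects subclasses
print(isinstance(include_modules, (bool, list)))  # True: isinstance accepts them
```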
2 changes: 1 addition & 1 deletion environment.yaml
@@ -4,7 +4,7 @@ channels:
 dependencies:
   - jupyter
   - jupyterlab
-  - keras ~= 3.7.0
+  - keras ~= 3.9.0
   - numpy ~= 1.26
   - matplotlib
   - pre-commit
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ license = { file = "LICENSE" }

 requires-python = ">= 3.10, < 3.12"
 dependencies = [
-    "keras ~= 3.7.0",
+    "keras ~= 3.9.0",
     "numpy ~= 1.26.4",
     "scipy ~= 1.14.1",
     "matplotlib",
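`~=` is a compatible-release pin, so `keras ~= 3.9.0` admits 3.9.x patch releases but excludes 3.10. A quick way to verify the pin semantics with the `packaging` library, if desired:

```python
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~= 3.9.0")  # equivalent to ">= 3.9.0, == 3.9.*"
print("3.9.2" in spec)   # True: patch releases stay inside the pin
print("3.10.0" in spec)  # False: the next minor release falls outside
```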