Skip to content

Commit 7694caf

Browse files
committed
Merge branch 'dev' into keras-3.9
2 parents a1354dd + f262300 commit 7694caf

File tree

8 files changed

+253
-135
lines changed

8 files changed

+253
-135
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ docsrc/.docs_venvs
1010
docsrc/source/api
1111
docsrc/source/_examples
1212
docsrc/source/contributing.md
13+
examples/checkpoints/
1314
build
1415
docs/
1516

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
repos:
1515
- repo: https://github.com/astral-sh/ruff-pre-commit
1616
# Ruff version.
17-
rev: v0.9.6
17+
rev: v0.9.10
1818
hooks:
1919
# Run the linter.
2020
- id: ruff

bayesflow/diagnostics/plots/z_score_contraction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def z_score_contraction(
121121
ax.scatter(contraction[:, i], z_score[:, i], color=color, alpha=0.5)
122122
ax.set_xlim([-0.05, 1.05])
123123

124-
prettify_subplots(plot_data["axes"], tick_fontsize)
124+
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
125125

126126
# Add labels, titles, and set font sizes
127127
add_titles_and_labels(

bayesflow/networks/transformers/fusion_transformer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from bayesflow.types import Tensor
66
from bayesflow.utils import check_lengths_same
7+
from bayesflow.utils.decorators import sanitize_input_shape
78

89
from ..summary_network import SummaryNetwork
910

@@ -151,3 +152,8 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens
151152
summary = self.attention_blocks[-1](keras.ops.expand_dims(template, axis=1), rep, training=training, **kwargs)
152153
summary = self.output_projector(keras.ops.squeeze(summary, axis=1))
153154
return summary
155+
156+
@sanitize_input_shape
157+
def build(self, input_shape):
158+
super().build(input_shape)
159+
self.call(keras.ops.zeros(input_shape))

bayesflow/networks/transformers/set_transformer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from bayesflow.types import Tensor
55
from bayesflow.utils import check_lengths_same
6+
from bayesflow.utils.decorators import sanitize_input_shape
67

78
from ..summary_network import SummaryNetwork
89

@@ -150,3 +151,8 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
150151
summary = self.pooling_by_attention(summary, training=training, **kwargs)
151152
summary = self.output_projector(summary)
152153
return summary
154+
155+
@sanitize_input_shape
156+
def build(self, input_shape):
157+
super().build(input_shape)
158+
self.call(keras.ops.zeros(input_shape))

bayesflow/networks/transformers/time_series_transformer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from bayesflow.types import Tensor
55
from bayesflow.utils import check_lengths_same
6+
from bayesflow.utils.decorators import sanitize_input_shape
67

78
from ..embeddings import Time2Vec, RecurrentEmbedding
89
from ..summary_network import SummaryNetwork
@@ -147,3 +148,8 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens
147148
summary = self.pooling(inp)
148149
summary = self.output_projector(summary)
149150
return summary
151+
152+
@sanitize_input_shape
153+
def build(self, input_shape):
154+
super().build(input_shape)
155+
self.call(keras.ops.zeros(input_shape))

bayesflow/utils/plot_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,8 @@ def add_y_labels(axes: np.ndarray, num_row: int = None, ylabel: Sequence[str] |
203203

204204

205205
def add_titles(axes: np.ndarray, title: Sequence[str] | str = None, title_fontsize: int = None):
206-
for i, ax in enumerate(axes.flat):
207-
ax.set_title(title[i], fontsize=title_fontsize)
206+
for t, ax in zip(title, axes.flat):
207+
ax.set_title(t, fontsize=title_fontsize)
208208

209209

210210
def add_titles_and_labels(

examples/Linear_Regression_Starter.ipynb

Lines changed: 230 additions & 131 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)