Skip to content

Commit ef6a32a

Browse files
committed
Add small doc improvements
1 parent 9767f6a commit ef6a32a

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

bayesflow/approximators/approximator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ def _batch_size_from_data(self, data: any):
142142
"""Obtain the batch size from a batch of data.
143143
144144
To properly weight the metrics for batches of different sizes, the batch size of a given batch of data is
145-
required. As the data structure differs between approximators, each approximator has to specify this method.
145+
required. As the data structure differs between approximators, each concrete approximator has to specify
146+
this method.
146147
147148
Parameters
148149
----------

bayesflow/approximators/continuous_approximator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,5 +492,9 @@ def _log_prob(
492492
**filter_kwargs(kwargs, self.inference_network.log_prob),
493493
)
494494

495-
def _batch_size_from_data(self, data: Mapping[str, any]):
495+
def _batch_size_from_data(self, data: Mapping[str, any]) -> int:
496+
"""
497+
Fetches the current batch size from an input dictionary. Can only be used during training when
498+
inference variables as present.
499+
"""
496500
return keras.ops.shape(data["inference_variables"])[0]

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,5 +379,9 @@ def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
379379
summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
380380
return summaries
381381

382-
def _batch_size_from_data(self, data: Mapping[str, any]):
382+
def _batch_size_from_data(self, data: Mapping[str, any]) -> int:
383+
"""
384+
Fetches the current batch size from an input dictionary. Can only be used during training when
385+
model indices as present.
386+
"""
383387
return keras.ops.shape(data["model_indices"])[0]

0 commit comments

Comments
 (0)