Skip to content

Commit 9c2059e

Browse files
committed
Fix typehints to docs.
1 parent ef6a32a commit 9c2059e

File tree

5 files changed

+88
-30
lines changed

5 files changed

+88
-30
lines changed

bayesflow/approximators/approximator.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from collections.abc import Mapping
2-
31
import multiprocessing as mp
42

53
import keras
@@ -22,7 +20,7 @@ def build_adapter(cls, **kwargs) -> Adapter:
2220
# implemented by each respective architecture
2321
raise NotImplementedError
2422

25-
def build_from_data(self, data: Mapping[str, any]) -> None:
23+
def build_from_data(self, data: dict[str, any]) -> None:
2624
self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training")
2725
self.built = True
2826

@@ -137,27 +135,3 @@ def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = N
137135
self.build_from_data(mock_data)
138136

139137
return super().fit(dataset=dataset, **kwargs)
140-
141-
def _batch_size_from_data(self, data: any):
142-
"""Obtain the batch size from a batch of data.
143-
144-
To properly weight the metrics for batches of different sizes, the batch size of a given batch of data is
145-
required. As the data structure differs between approximators, each concrete approximator has to specify
146-
this method.
147-
148-
Parameters
149-
----------
150-
data :
151-
The data that are passed to `compute_metrics` as keyword arguments.
152-
153-
Returns
154-
-------
155-
batch_size : int
156-
The batch size of the given data.
157-
"""
158-
raise NotImplementedError(
159-
"Correct calculation of the metrics requires obtaining the batch size from the supplied data "
160-
"for proper weighting of metrics for batches with different sizes. Please implement the "
161-
"_batch_size_from_data method for your approximator. For a given batch of data, it should "
162-
"return the corresponding batch size."
163-
)

bayesflow/approximators/backend_approximators/jax_approximator.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
3131
----------
3232
*args : tuple
3333
Positional arguments passed to the metric computation function.
34-
**kwargs : dict
34+
**kwargs
3535
Keyword arguments passed to the metric computation function.
3636
3737
Returns
@@ -222,3 +222,28 @@ def _update_metrics(self, loss: jax.Array, metrics_variables: any, sample_weight
222222
metrics_variables = [scope.get_current_value(v) for v in self.metrics_variables]
223223

224224
return metrics_variables
225+
226+
# noinspection PyMethodOverriding
227+
def _batch_size_from_data(self, data: any) -> int:
228+
"""Obtain the batch size from a batch of data.
229+
230+
To properly weigh the metrics for batches of different sizes, the batch size of a given batch of data is
231+
required. As the data structure differs between approximators, each concrete approximator has to specify
232+
this method.
233+
234+
Parameters
235+
----------
236+
data :
237+
The data that are passed to `compute_metrics` as keyword arguments.
238+
239+
Returns
240+
-------
241+
batch_size : int
242+
The batch size of the given data.
243+
"""
244+
raise NotImplementedError(
245+
"Correct calculation of the metrics requires obtaining the batch size from the supplied data "
246+
"for proper weighting of metrics for batches with different sizes. Please implement the "
247+
"_batch_size_from_data method for your approximator. For a given batch of data, it should "
248+
"return the corresponding batch size."
249+
)

bayesflow/approximators/backend_approximators/numpy_approximator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,12 @@ def _update_metrics(self, metrics, sample_weight=None):
2727
except ValueError:
2828
self._metrics.append(keras.metrics.Mean(name=name))
2929
self._metrics[-1].update_state(value, sample_weight=sample_weight)
30+
31+
# noinspection PyMethodOverriding
32+
def _batch_size_from_data(self, data: any) -> int:
33+
raise NotImplementedError(
34+
"Correct calculation of the metrics requires obtaining the batch size from the supplied data "
35+
"for proper weighting of metrics for batches with different sizes. Please implement the "
36+
"_batch_size_from_data method for your approximator. For a given batch of data, it should "
37+
"return the corresponding batch size."
38+
)

bayesflow/approximators/backend_approximators/tensorflow_approximator.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]:
3232
----------
3333
*args : tuple
3434
Positional arguments passed to the metric computation function.
35-
**kwargs : dict
35+
**kwargs
3636
Keyword arguments passed to the metric computation function.
3737
3838
Returns
@@ -114,3 +114,28 @@ def _update_metrics(self, metrics: dict[str, any], sample_weight: tf.Tensor = No
114114
except ValueError:
115115
self._metrics.append(keras.metrics.Mean(name=name))
116116
self._metrics[-1].update_state(value, sample_weight=sample_weight)
117+
118+
# noinspection PyMethodOverriding
119+
def _batch_size_from_data(self, data: any) -> int:
120+
"""Obtain the batch size from a batch of data.
121+
122+
To properly weigh the metrics for batches of different sizes, the batch size of a given batch of data is
123+
required. As the data structure differs between approximators, each concrete approximator has to specify
124+
this method.
125+
126+
Parameters
127+
----------
128+
data :
129+
The data that are passed to `compute_metrics` as keyword arguments.
130+
131+
Returns
132+
-------
133+
batch_size : int
134+
The batch size of the given data.
135+
"""
136+
raise NotImplementedError(
137+
"Correct calculation of the metrics requires obtaining the batch size from the supplied data "
138+
"for proper weighting of metrics for batches with different sizes. Please implement the "
139+
"_batch_size_from_data method for your approximator. For a given batch of data, it should "
140+
"return the corresponding batch size."
141+
)

bayesflow/approximators/backend_approximators/torch_approximator.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def compute_metrics(self, *args, **kwargs) -> dict[str, torch.Tensor]:
3333
----------
3434
*args : tuple
3535
Positional arguments passed to the metric computation function.
36-
**kwargs : dict
36+
**kwargs
3737
Keyword arguments passed to the metric computation function.
3838
3939
Returns
@@ -124,3 +124,28 @@ def _update_metrics(self, metrics: dict[str, any], sample_weight: torch.Tensor =
124124
except ValueError:
125125
self._metrics.append(keras.metrics.Mean(name=name))
126126
self._metrics[-1].update_state(value, sample_weight=sample_weight)
127+
128+
# noinspection PyMethodOverriding
129+
def _batch_size_from_data(self, data: any) -> int:
130+
"""Obtain the batch size from a batch of data.
131+
132+
To properly weigh the metrics for batches of different sizes, the batch size of a given batch of data is
133+
required. As the data structure differs between approximators, each concrete approximator has to specify
134+
this method.
135+
136+
Parameters
137+
----------
138+
data :
139+
The data that are passed to `compute_metrics` as keyword arguments.
140+
141+
Returns
142+
-------
143+
batch_size : int
144+
The batch size of the given data.
145+
"""
146+
raise NotImplementedError(
147+
"Correct calculation of the metrics requires obtaining the batch size from the supplied data "
148+
"for proper weighting of metrics for batches with different sizes. Please implement the "
149+
"_batch_size_from_data method for your approximator. For a given batch of data, it should "
150+
"return the corresponding batch size."
151+
)

0 commit comments

Comments
 (0)