
Commit 554ed13

Merge pull request #268 from eodole/dev
Added Documentation to Adapter Transforms
2 parents 0cc5b58 + 25cc326

8 files changed: +66 −11 lines

bayesflow/adapters/transforms/concatenate.py

Lines changed: 15 additions & 6 deletions
@@ -12,13 +12,22 @@

 @serializable(package="bayesflow.adapters")
 class Concatenate(Transform):
-    """Concatenate multiple arrays into a new key.
-    Parameters:
-
-    keys:
-
-    into:
+    """Concatenate multiple arrays into a new key. Used to specify how data variables should be treated by the network.

+    Parameters:
+    keys: A list of strings, where the strings are the names of data variables.
+    into: A string telling the network how to use the variables named in keys.
+    axis: Integer specifying along which axis to concatenate the keys. The last axis is used by default.
+
+    Example:
+    Suppose you have a simulator that generates variables "beta" and "sigma" from priors and then observation
+    variables "x" and "y". We can then use concatenate in the following way:
+
+    adapter = (
+        bf.Adapter()
+        .concatenate(["beta", "sigma"], into="inference_variables")
+        .concatenate(["x", "y"], into="summary_variables")
+    )
     """

     def __init__(self, keys: Sequence[str], *, into: str, axis: int = -1):
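To make the docstring's example concrete, here is a sketch of how it might run end to end. The array shapes are invented for illustration, and applying the adapter via forward is an assumption by analogy with the transforms below, not something this diff shows.

    import numpy as np
    import bayesflow as bf

    # Toy simulator output; variable names follow the docstring example,
    # shapes are invented for illustration.
    data = {
        "beta": np.random.normal(size=(32, 2)),
        "sigma": np.random.uniform(size=(32, 1)),
        "x": np.random.normal(size=(32, 10)),
        "y": np.random.normal(size=(32, 10)),
    }

    adapter = (
        bf.Adapter()
        .concatenate(["beta", "sigma"], into="inference_variables")
        .concatenate(["x", "y"], into="summary_variables")
    )

    # Concatenation uses the last axis by default, so we would expect
    # "inference_variables" of shape (32, 3) and "summary_variables" of shape (32, 20).
    result = adapter.forward(data)  # forward() assumed here, by analogy with Drop below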

bayesflow/adapters/transforms/constrain.py

Lines changed: 3 additions & 3 deletions
@@ -32,7 +32,7 @@ class Constrain(ElementwiseTransform):


     Examples:
-        Let sigma be the standard deviation of a normal distribution,
+        1) Let sigma be the standard deviation of a normal distribution,
         then sigma should always be greater than zero.

     Usage:
@@ -41,8 +41,8 @@ class Constrain(ElementwiseTransform):
         .constrain("sigma", lower=0)
     )

-        Suppose p is the parameter for a binomial distribution where p must be in [0,1]
-        then we would constrain the neural network to estimate p in the following way.
+        2) Suppose p is the parameter for a binomial distribution where p must be in
+        [0,1] then we would constrain the neural network to estimate p in the following way.

     Usage:
     adapter = (
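The diff truncates before the second Usage block completes. Combining both documented cases might look like the sketch below; the upper keyword is assumed by symmetry with the lower keyword shown above, not confirmed by this diff.

    import bayesflow as bf

    adapter = (
        bf.Adapter()
        .constrain("sigma", lower=0)       # case 1: sigma > 0
        .constrain("p", lower=0, upper=1)  # case 2: p in [0, 1]; upper= is assumed
    )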

bayesflow/adapters/transforms/drop.py

Lines changed: 20 additions & 0 deletions
@@ -11,6 +11,26 @@

 @serializable(package="bayesflow.adapters")
 class Drop(Transform):
+    """
+    Transform to drop variables from further calculation.
+
+    Parameters:
+    keys: List of strings containing the names of the data variables that should be dropped.
+
+    Example:
+
+    >>> import bayesflow as bf
+    >>> a = [1, 2, 3, 4]
+    >>> b = [[1, 2], [3, 4]]
+    >>> c = [[5, 6, 7, 8]]
+    >>> dat = dict(a=a, b=b, c=c)
+    >>> dat
+    {'a': [1, 2, 3, 4], 'b': [[1, 2], [3, 4]], 'c': [[5, 6, 7, 8]]}
+    >>> drop = bf.adapters.transforms.Drop(("b", "c"))
+    >>> drop.forward(dat)
+    {'a': [1, 2, 3, 4]}
+    """
+
     def __init__(self, keys: Sequence[str]):
         self.keys = keys


bayesflow/adapters/transforms/elementwise_transform.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@

 @serializable(package="bayesflow.adapters")
 class ElementwiseTransform:
+    """Base class on which other transforms are based."""
+
     def __call__(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
         if inverse:
             return self.inverse(data, **kwargs)
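Since the diff only shows the inverse branch of __call__, here is a hypothetical subclass illustrating the intended contract. The Log class is invented for illustration, and the assumption that the non-inverse branch dispatches to forward is mine, not the PR's.

    import numpy as np
    from bayesflow.adapters.transforms.elementwise_transform import ElementwiseTransform

    # Hypothetical elementwise transform, for illustration only.
    class Log(ElementwiseTransform):
        """Takes the natural log forward, exponentiates on the inverse pass."""

        def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
            return np.log(data)

        def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
            return np.exp(data)

    log = Log()
    x = np.array([1.0, 2.0, 4.0])
    z = log(x)                     # assumed to dispatch to forward
    x_back = log(z, inverse=True)  # dispatches to inverse, as shown in the diff
    assert np.allclose(x, x_back)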

bayesflow/adapters/transforms/keep.py

Lines changed: 3 additions & 2 deletions
@@ -25,8 +25,9 @@ class Keep(Transform):

     adapter = (
         bf.adapters.Adapter()
-        # only keep theta and x
-        .keep(("theta", "x"))
+        # drop data from unneeded priors alpha and r
+        # only keep theta and x
+        .keep(("theta", "x"))
     )

     Example:
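The docstring's own Example section is cut off by the diff. As a stand-in, here is a sketch mirroring the Drop doctest above, assuming Keep's constructor and forward behave symmetrically to Drop's; the variable names and values are invented.

    >>> import bayesflow as bf
    >>> dat = dict(theta=[0.1, 0.2], x=[1, 2, 3], alpha=[5])
    >>> keep = bf.adapters.transforms.Keep(("theta", "x"))
    >>> keep.forward(dat)
    {'theta': [0.1, 0.2], 'x': [1, 2, 3]}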

bayesflow/adapters/transforms/one_hot.py

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,10 @@

 @serializable(package="bayesflow.adapters")
 class OneHot(ElementwiseTransform):
+    """
+    Changes data to be one-hot encoded.
+    """
+
     def __init__(self, num_classes: int):
         super().__init__()
         self.num_classes = num_classes
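As a conceptual illustration of what one-hot encoding does (plain NumPy, not the library's internals):

    import numpy as np

    num_classes = 4
    labels = np.array([0, 2, 3])

    # Each class index i becomes a length-num_classes vector with a 1 at position i.
    encoded = np.eye(num_classes)[labels]
    # array([[1., 0., 0., 0.],
    #        [0., 0., 1., 0.],
    #        [0., 0., 0., 1.]])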

bayesflow/adapters/transforms/standardize.py

Lines changed: 15 additions & 0 deletions
@@ -10,6 +10,21 @@

 @serializable(package="bayesflow.adapters")
 class Standardize(ElementwiseTransform):
+    """
+    Transform that, when applied, standardizes data using typical z-score standardization, i.e. for some
+    unstandardized data x the standardized version z would be
+
+    z = (x - mean(x)) / std(x)
+
+    Parameters:
+    mean: Integer or float used to specify a mean if known; estimated from data when not provided.
+    std: Integer or float used to specify a standard deviation if known; estimated from data when not provided.
+    axis: Integer representing a specific axis along which standardization should take place. By default,
+        standardization happens individually for each dimension.
+    momentum: Float in (0, 1) specifying the momentum during training.
+
+    """
+
     def __init__(
         self,
         mean: int | float | np.ndarray = None,
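A plain-NumPy sketch of the z-score formula from the docstring (illustrative only; the transform itself additionally supports a fixed mean/std, an axis argument, and a training momentum):

    import numpy as np

    x = np.array([2.0, 4.0, 6.0, 8.0])

    # z = (x - mean(x)) / std(x)
    z = (x - x.mean()) / x.std()

    # The standardized data has zero mean and unit standard deviation.
    assert np.isclose(z.mean(), 0.0) and np.isclose(z.std(), 1.0)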

bayesflow/adapters/transforms/transform.py

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@

 @serializable(package="bayesflow.adapters")
 class Transform:
+    """
+    Base class on which other transforms are based
+    """
+
     def __call__(self, data: dict[str, np.ndarray], *, inverse: bool = False, **kwargs) -> dict[str, np.ndarray]:
         if inverse:
             return self.inverse(data, **kwargs)
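Unlike ElementwiseTransform, Transform operates on the whole data dictionary. A hypothetical subclass sketch follows; the Rename class is invented for illustration, and the forward dispatch is assumed as above.

    import numpy as np
    from bayesflow.adapters.transforms.transform import Transform

    # Hypothetical dict-level transform, for illustration only.
    class Rename(Transform):
        """Renames one key in the data dictionary; the inverse renames it back."""

        def __init__(self, from_key: str, to_key: str):
            self.from_key = from_key
            self.to_key = to_key

        def forward(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
            data = data.copy()
            data[self.to_key] = data.pop(self.from_key)
            return data

        def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
            data = data.copy()
            data[self.from_key] = data.pop(self.to_key)
            return data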
