Skip to content

Commit a611f70

Browse files
committed
[no ci] minor fixes to RandomSubsample transform
1 parent ca9e245 commit a611f70

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

bayesflow/adapters/transforms/random_subsample.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ class RandomSubsample(ElementwiseTransform):
88
"""
99
A transform that takes a random subsample of the data within an axis.
1010
11-
Example: adapter.random_subsample("x", sample_size = 3, axis = -1)
11+
Examples
12+
--------
1213
14+
>>> adapter = bf.Adapter().random_subsample("x", sample_size=3, axis=-1)
1315
"""
1416

1517
def __init__(
@@ -20,23 +22,22 @@ def __init__(
2022
super().__init__()
2123
if isinstance(sample_size, float):
2224
if sample_size <= 0 or sample_size >= 1:
23-
ValueError("Sample size as a percentage must be a float between 0 and 1 exclusive. ")
25+
raise ValueError("Sample size as a percentage must be a float between 0 and 1 exclusive. ")
2426
self.sample_size = sample_size
2527
self.axis = axis
2628

2729
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
28-
axis = self.axis
29-
max_sample_size = data.shape[axis]
30+
max_sample_size = data.shape[self.axis]
3031

3132
if isinstance(self.sample_size, int):
3233
sample_size = self.sample_size
3334
else:
3435
sample_size = np.round(self.sample_size * max_sample_size)
3536

3637
# random sample without replacement
37-
sample_indices = np.random.permutation(max_sample_size)[0 : sample_size - 1]
38+
sample_indices = np.random.permutation(max_sample_size)[:sample_size]
3839

39-
return np.take(data, sample_indices, axis)
40+
return np.take(data, sample_indices, self.axis)
4041

4142
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
4243
# non invertible transform

0 commit comments

Comments
 (0)