@@ -8,8 +8,10 @@ class RandomSubsample(ElementwiseTransform):
88 """
99 A transform that takes a random subsample of the data within an axis.
1010
11- Example: adapter.random_subsample("x", sample_size = 3, axis = -1)
11+ Examples
12+ --------
1213
14+ >>> adapter = bf.Adapter().random_subsample("x", sample_size=3, axis=-1)
1315 """
1416
1517 def __init__ (
@@ -20,23 +22,22 @@ def __init__(
2022 super ().__init__ ()
2123 if isinstance (sample_size , float ):
2224 if sample_size <= 0 or sample_size >= 1 :
23- ValueError ("Sample size as a percentage must be a float between 0 and 1 exclusive. " )
25+ raise ValueError ("Sample size as a percentage must be a float between 0 and 1 exclusive. " )
2426 self .sample_size = sample_size
2527 self .axis = axis
2628
2729 def forward (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
28- axis = self .axis
29- max_sample_size = data .shape [axis ]
30+ max_sample_size = data .shape [self .axis ]
3031
3132 if isinstance (self .sample_size , int ):
3233 sample_size = self .sample_size
3334 else :
3435 sample_size = np .round (self .sample_size * max_sample_size )
3536
3637 # random sample without replacement
37- sample_indices = np .random .permutation (max_sample_size )[0 : sample_size - 1 ]
38+ sample_indices = np .random .permutation (max_sample_size )[: sample_size ]
3839
39- return np .take (data , sample_indices , axis )
40+ return np .take (data , sample_indices , self . axis )
4041
4142 def inverse (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
4243 # non invertible transform
0 commit comments