Skip to content

Commit 7e3911b

Browse files
committed
ran linter
1 parent 4218b70 commit 7e3911b

File tree

2 files changed

+33
-44
lines changed

2 files changed

+33
-44
lines changed

bayesflow/adapters/adapter.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
ToArray,
2929
Transform,
3030
RandomSubsample,
31-
Take
31+
Take,
3232
)
3333
from .transforms.filter_transform import Predicate
3434

@@ -543,16 +543,17 @@ def one_hot(self, keys: str | Sequence[str], num_classes: int):
543543
transform = MapTransform({key: OneHot(num_classes=num_classes) for key in keys})
544544
self.transforms.append(transform)
545545
return self
546-
547-
def random_subsample(self,
546+
547+
def random_subsample(
548+
self,
548549
key: str | Sequence[str],
549550
*,
550551
sample_size: int | float,
551-
axis: int=-1,
552+
axis: int = -1,
552553
**kwargs,
553554
):
554555
"""
555-
Append a :py:class:`~transforms.SubsampleArray` transform to the adapter.
556+
Append a :py:class:`~transforms.RandomSubsample` transform to the adapter.
556557
557558
Parameters
558559
----------
@@ -564,22 +565,19 @@ def random_subsample(self,
564565
Names of variables to exclude from the transform.
565566
**kwargs : dict
566567
Additional keyword arguments passed to the transform.
567-
568+
568569
"""
569-
570-
571-
if isinstance(key, Sequence[str]) and len(keys) >1:
572-
TypeError("`key` should be either a string or a list of length one. Only one dataset may be modified at a time.")
570+
571+
if isinstance(key, Sequence[str]) and len(key) > 1:
572+
TypeError(
573+
"`key` should be either a string or a list of length one. Only one dataset may be modified at a time."
574+
)
573575

574576
if isinstance(key, str):
575577
keys = [key]
576578

577579
transform = MapTransform(
578-
transform_map={
579-
key:RandomSubsample(sample_size=sample_size, axis=axis)
580-
for key in keys
581-
}
582-
580+
transform_map={key: RandomSubsample(sample_size=sample_size, axis=axis) for key in keys}
583581
)
584582

585583
self.transforms.append(transform)
@@ -644,14 +642,16 @@ def standardize(
644642
self.transforms.append(transform)
645643
return self
646644

647-
def take(self,
648-
indices,
645+
def take(
646+
self,
647+
indices,
649648
axis,
650649
*,
651650
predicate: Predicate = None,
652651
include: str | Sequence[str] = None,
653652
exclude: str | Sequence[str] = None,
654-
**kwargs,):
653+
**kwargs,
654+
):
655655
"""
656656
Append a :py:class:`~transforms.Take` transform to the adapter.
657657
@@ -664,7 +664,7 @@ def take(self,
664664
exclude : str or Sequence of str, optional
665665
Names of variables to exclude from the transform.
666666
**kwargs : dict
667-
Additional keyword arguments passed to the transform. """
667+
Additional keyword arguments passed to the transform."""
668668
transform = FilterTransform(
669669
transform_constructor=Take(indices=indices, axis=axis),
670670
predicate=predicate,
@@ -674,8 +674,7 @@ def take(self,
674674
)
675675
self.transforms.append(transform)
676676
return self
677-
678-
677+
679678
def to_array(
680679
self,
681680
*,
@@ -693,7 +692,7 @@ def to_array(
693692
include : str or Sequence of str, optional
694693
Names of variables to include in the transform.
695694
exclude : str or Sequence of str, optional
696-
Names of variabxles to exclude from the transform.
695+
Names of variables to exclude from the transform.
697696
**kwargs : dict
698697
Additional keyword arguments passed to the transform.
699698
"""
@@ -706,9 +705,3 @@ def to_array(
706705
)
707706
self.transforms.append(transform)
708707
return self
709-
710-
711-
712-
713-
714-

bayesflow/adapters/transforms/random_subsample.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
2-
from keras.saving import register_keras_serializable as serializable
3-
2+
from bayesflow.utils.serialization import serializable
43
from .elementwise_transform import ElementwiseTransform
54

65

@@ -14,31 +13,28 @@ class RandomSubsample(ElementwiseTransform):
1413
"""
1514

1615
def __init__(
17-
self,
18-
sample_size: int | float,
19-
axis: int = -1,
20-
):
16+
self,
17+
sample_size: int | float,
18+
axis: int = -1,
19+
):
2120
super().__init__()
2221
if isinstance(sample_size, float):
23-
if sample_size <= 0 or sample_size >= 1:
22+
if sample_size <= 0 or sample_size >= 1:
2423
ValueError("Sample size as a percentage must be a float between 0 and 1 exclustive. ")
2524
self.sample_size = sample_size
26-
self.axis = axis
27-
25+
self.axis = axis
2826

2927
def forward(self, data: np.ndarray):
30-
31-
axis = self.axis
28+
axis = self.axis
3229
max_sample_size = data.shape[axis]
33-
30+
3431
if isinstance(self.sample_size, int):
3532
sample_size = self.sample_size
36-
else:
33+
else:
3734
sample_size = np.round(self.sample_size * max_sample_size)
3835

39-
sample_indices = np.random.permutation(max_sample_size)[
40-
0 : sample_size - 1
41-
] # random sample without replacement
36+
# random sample without replacement
37+
sample_indices = np.random.permutation(max_sample_size)[0 : sample_size - 1]
4238

4339
return np.take(data, sample_indices, axis)
4440

0 commit comments

Comments
 (0)