Skip to content

Commit 5ae83bc

Browse files
committed
Fix silly type check and improve readability of for loop
1 parent f62f5f2 commit 5ae83bc

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

bayesflow/datasets/disk_dataset.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,11 @@ def __getitem__(self, item) -> dict[str, np.ndarray]:
9292

9393
batch = tree_stack(batch)
9494

95-
if isinstance(self.augmentations, Mapping):
96-
for key in self.augmentations:
97-
batch[key] = self.augmentations[key](batch[key])
95+
if self.augmentations is None:
96+
pass
97+
elif isinstance(self.augmentations, Mapping):
98+
for key, fn in self.augmentations.items():
99+
batch[key] = fn(batch[key])
98100
elif isinstance(self.augmentations, Callable):
99101
batch = self.augmentations(batch)
100102
else:

bayesflow/datasets/offline_dataset.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,11 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
102102
for key, value in self.data.items()
103103
}
104104

105-
if isinstance(self.augmentations, Mapping):
106-
for key in self.augmentations:
107-
batch[key] = self.augmentations[key](batch[key])
105+
if self.augmentations is None:
106+
pass
107+
elif isinstance(self.augmentations, Mapping):
108+
for key, fn in self.augmentations.items():
109+
batch[key] = fn(batch[key])
108110
elif isinstance(self.augmentations, Callable):
109111
batch = self.augmentations(batch)
110112
else:

bayesflow/datasets/online_dataset.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,11 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
7676
"""
7777
batch = self.simulator.sample((self.batch_size,))
7878

79-
if isinstance(self.augmentations, Mapping):
80-
for key in self.augmentations:
81-
batch[key] = self.augmentations[key](batch[key])
79+
if self.augmentations is None:
80+
pass
81+
elif isinstance(self.augmentations, Mapping):
82+
for key, fn in self.augmentations.items():
83+
batch[key] = fn(batch[key])
8284
elif isinstance(self.augmentations, Callable):
8385
batch = self.augmentations(batch)
8486
else:

0 commit comments

Comments
 (0)