Skip to content

Commit 1bf0ae8

Browse files
committed
Merge branch 'dev' of https://github.com/stefanradev93/BayesFlow into dev
2 parents 79d727a + c3c1456 commit 1bf0ae8

File tree

7 files changed

+32
-13
lines changed

7 files changed

+32
-13
lines changed

bayesflow/adapters/adapter.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,15 @@ def get_config(self) -> dict:
7979

8080
return serialize(config)
8181

82-
def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
82+
def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -> dict[str, np.ndarray]:
8383
"""Apply the transforms in the forward direction.
8484
8585
Parameters
8686
----------
8787
data : dict
8888
The data to be transformed.
89+
stage : str, one of ["training", "validation", "inference"]
90+
The stage the function is called in.
8991
**kwargs : dict
9092
Additional keyword arguments passed to each transform.
9193
@@ -97,17 +99,19 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
9799
data = data.copy()
98100

99101
for transform in self.transforms:
100-
data = transform(data, **kwargs)
102+
data = transform(data, stage=stage, **kwargs)
101103

102104
return data
103105

104-
def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]:
106+
def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kwargs) -> dict[str, any]:
105107
"""Apply the transforms in the inverse direction.
106108
107109
Parameters
108110
----------
109111
data : dict
110112
The data to be transformed.
113+
stage : str, one of ["training", "validation", "inference"]
114+
The stage the function is called in.
111115
**kwargs : dict
112116
Additional keyword arguments passed to each transform.
113117
@@ -119,11 +123,13 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]:
119123
data = data.copy()
120124

121125
for transform in reversed(self.transforms):
122-
data = transform(data, inverse=True, **kwargs)
126+
data = transform(data, stage=stage, inverse=True, **kwargs)
123127

124128
return data
125129

126-
def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs) -> dict[str, np.ndarray]:
130+
def __call__(
131+
self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs
132+
) -> dict[str, np.ndarray]:
127133
"""Apply the transforms in the given direction.
128134
129135
Parameters
@@ -132,6 +138,8 @@ def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs)
132138
The data to be transformed.
133139
inverse : bool, optional
134140
If False, apply the forward transform, else apply the inverse transform (default False).
141+
stage : str, one of ["training", "validation", "inference"]
142+
The stage the function is called in.
135143
**kwargs
136144
Additional keyword arguments passed to each transform.
137145
@@ -141,9 +149,9 @@ def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs)
141149
The transformed data.
142150
"""
143151
if inverse:
144-
return self.inverse(data, **kwargs)
152+
return self.inverse(data, stage=stage, **kwargs)
145153

146-
return self.forward(data, **kwargs)
154+
return self.forward(data, stage=stage, **kwargs)
147155

148156
def __repr__(self):
149157
result = ""

bayesflow/adapters/transforms/standardize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def get_config(self) -> dict:
8787
}
8888
return serialize(config)
8989

90-
def forward(self, data: np.ndarray, stage: str = "training", **kwargs) -> np.ndarray:
90+
def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
9191
if self.axis is None:
9292
self.axis = tuple(range(data.ndim - 1))
9393

bayesflow/datasets/disk_dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(
3131
batch_size: int,
3232
load_fn: callable = None,
3333
adapter: Adapter | None,
34+
stage: str = "training",
3435
**kwargs,
3536
):
3637
super().__init__(**kwargs)
@@ -39,6 +40,7 @@ def __init__(
3940
self.load_fn = load_fn or pickle_load
4041
self.adapter = adapter
4142
self.files = list(map(str, self.root.glob(pattern)))
43+
self.stage = stage
4244

4345
self.shuffle()
4446

@@ -55,7 +57,7 @@ def __getitem__(self, item) -> dict[str, np.ndarray]:
5557
batch = tree_stack(batch)
5658

5759
if self.adapter is not None:
58-
batch = self.adapter(batch)
60+
batch = self.adapter(batch, stage=self.stage)
5961

6062
return batch
6163

bayesflow/datasets/offline_dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@ def __init__(
2121
batch_size: int,
2222
adapter: Adapter | None,
2323
num_samples: int = None,
24+
*,
25+
stage: str = "training",
2426
**kwargs,
2527
):
2628
super().__init__(**kwargs)
2729
self.batch_size = batch_size
2830
self.data = data
2931
self.adapter = adapter
32+
self.stage = stage
3033

3134
if num_samples is None:
3235
self.num_samples = self._get_num_samples_from_data(data)
@@ -52,7 +55,7 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
5255
}
5356

5457
if self.adapter is not None:
55-
batch = self.adapter(batch)
58+
batch = self.adapter(batch, stage=self.stage)
5659

5760
return batch
5861

bayesflow/datasets/online_dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ def __init__(
1616
batch_size: int,
1717
num_batches: int,
1818
adapter: Adapter | None,
19+
*,
20+
stage: str = "training",
1921
**kwargs,
2022
):
2123
super().__init__(**kwargs)
@@ -24,12 +26,13 @@ def __init__(
2426
self._num_batches = num_batches
2527
self.adapter = adapter
2628
self.simulator = simulator
29+
self.stage = stage
2730

2831
def __getitem__(self, item: int) -> dict[str, np.ndarray]:
2932
batch = self.simulator.sample((self.batch_size,))
3033

3134
if self.adapter is not None:
32-
batch = self.adapter(batch)
35+
batch = self.adapter(batch, stage=self.stage)
3336

3437
return batch
3538

bayesflow/datasets/rounds_dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def __init__(
1818
num_batches: int,
1919
epochs_per_round: int,
2020
adapter: Adapter | None,
21+
*,
22+
stage: str = "training",
2123
**kwargs,
2224
):
2325
super().__init__(**kwargs)
@@ -27,6 +29,7 @@ def __init__(
2729
self.batch_size = batch_size
2830
self.adapter = adapter
2931
self.epoch = 0
32+
self.stage = stage
3033

3134
if epochs_per_round == 1:
3235
logging.warning(
@@ -45,7 +48,7 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
4548
batch = self.batches[item]
4649

4750
if self.adapter is not None:
48-
batch = self.adapter(batch)
51+
batch = self.adapter(batch, stage=self.stage)
4952

5053
return batch
5154

examples/Linear_Regression_Starter.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@
376376
"\n",
377377
"To ensure that the training data generated by the simulator can be used for deep learning, we have do a bunch of transformations via `adapter` objects. They provides multiple flexible functionalities, from standardization to renaming, and so on. \n",
378378
"\n",
379-
"Below, we build our own `adapter` from scratch but later on, `BayesFlo` will also provide default adapters that will already automate most of the commonly required steps."
379+
"Below, we build our own `adapter` from scratch but later on, `BayesFlow` will also provide default adapters that will already automate most of the commonly required steps."
380380
]
381381
},
382382
{

0 commit comments

Comments
 (0)