Skip to content

Commit e2f93dc

Browse files
committed
Only apply NNPE during training
1 parent 3a98176 commit e2f93dc

File tree

1 file changed

+3
-1
lines changed
  • bayesflow/adapters/transforms

1 file changed

+3
-1
lines changed

bayesflow/adapters/transforms/nnpe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def __init__(self, *, slab_scale: float = 0.25, spike_scale: float = 0.01, seed:
4747
self.seed = seed
4848
self.rng = np.random.default_rng(seed)
4949

50-
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
50+
def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
51+
if stage != "training":
52+
return data
5153
mixture_mask = self.rng.binomial(n=1, p=0.5, size=data.shape).astype(bool)
5254
noise_slab = self.rng.standard_cauchy(size=data.shape) * self.slab_scale
5355
noise_spike = self.rng.standard_normal(size=data.shape) * self.spike_scale

0 commit comments

Comments
 (0)