@@ -19,37 +19,41 @@ class NNPE(ElementwiseTransform):
1919
2020 Parameters
2121 ----------
22- slab_scale : float
23- The scale of the slab (Cauchy ) distribution.
24- spike_scale : float
25- The scale of the spike spike (Normal ) distribution.
22+ spike_scale : float or None
23+ The scale of the spike (Normal ) distribution. Automatically determined if None (see “Notes” section) .
24+ slab_scale : float or None
25+ The scale of the slab (Cauchy ) distribution. Automatically determined if None (see “Notes” section) .
2626 seed : int or None
2727 The seed for the random number generator. If None, a random seed is used. Used instead of np.random.Generator
2828 here to enable easy serialization.
2929
3030 Notes
3131 -----
32- The spike-and-slab distribution consists of a mixture of a Cauchy (slab) and a Normal distribution (spike), which
33- are applied based on a Bernoulli random variable with p=0.5.
32+ The spike-and-slab distribution consists of a mixture of a Normal distribution (spike) and Cauchy distribution
33+ (slab), which are applied based on a Bernoulli random variable with p=0.5.
3434
35- The default scales follow [1] and expect standardized data (e.g., via the `Standardize` adapter). It is therefore
36- recommended to adapt the scales when using unstandardized training data.
35+ The scales of the spike and slab distributions can be set manually, or they are automatically determined by scaling
36+ the default scales of [1] (which expect standardized data) by the standard deviation of the input data.
3737
3838 Examples
3939 --------
4040 >>> adapter = bf.Adapter().nnpe(["x"])
4141 """
4242
43- def __init__ (self , * , slab_scale : float = 0.25 , spike_scale : float = 0.01 , seed : int = None ):
43+ DEFAULT_SLAB = 0.25
44+ DEFAULT_SPIKE = 0.01
45+
46+ def __init__ (self , * , spike_scale : float | None = None , slab_scale : float | None = None , seed : int | None = None ):
4447 super ().__init__ ()
45- self .slab_scale = slab_scale
4648 self .spike_scale = spike_scale
49+ self .slab_scale = slab_scale
4750 self .seed = seed
4851 self .rng = np .random .default_rng (seed )
4952
5053 def forward (self , data : np .ndarray , stage : str = "inference" , ** kwargs ) -> np .ndarray :
5154 """
52- Add spike‐and‐slab noise (see “Notes” section of the class docstring for details) to `data` during training.
55+ Add spike‐and‐slab noise to `data` during training, using automatic scale determination if not provided (see
56+ “Notes” section of the class docstring for details).
5357
5458 Parameters
5559 ----------
@@ -67,9 +71,21 @@ def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.nd
6771 """
6872 if stage != "training" :
6973 return data
74+
75+ # Check data validity
76+ if not np .all (np .isfinite (data )):
77+ raise ValueError ("NNPE.forward: `data` contains NaN or infinite values." )
78+
79+ # Automatically determine scales if not provided
80+ if self .spike_scale is None or self .slab_scale is None :
81+ data_std = np .std (data )
82+ spike_scale = self .spike_scale if self .spike_scale is not None else self .DEFAULT_SPIKE * data_std
83+ slab_scale = self .slab_scale if self .slab_scale is not None else self .DEFAULT_SLAB * data_std
84+
85+ # Apply spike-and-slab noise
7086 mixture_mask = self .rng .binomial (n = 1 , p = 0.5 , size = data .shape ).astype (bool )
71- noise_slab = self .rng .standard_cauchy (size = data .shape ) * self . slab_scale
72- noise_spike = self .rng .standard_normal (size = data .shape ) * self . spike_scale
87+ noise_spike = self .rng .standard_normal (size = data .shape ) * spike_scale
88+ noise_slab = self .rng .standard_cauchy (size = data .shape ) * slab_scale
7389 noise = np .where (mixture_mask , noise_slab , noise_spike )
7490 return data + noise
7591
@@ -78,4 +94,4 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
7894 return data
7995
8096 def get_config (self ) -> dict :
81- return serialize ({"slab_scale " : self .slab_scale , "spike_scale " : self .spike_scale , "seed" : self .seed })
97+ return serialize ({"spike_scale " : self .spike_scale , "slab_scale " : self .slab_scale , "seed" : self .seed })
0 commit comments