@@ -41,14 +41,14 @@ def _relaxed_bernoulli_transform(
4141 Parameters
4242 ----------
4343 p : float or PyMC variable
44- Probability parameter
44+ Probability parameter.
4545 temperature : float, default=0.1
46- Temperature parameter (lower = more binary)
46+ Temperature parameter (lower = more binary).
4747
4848 Returns
4949 -------
5050 function
51- Transform function that takes uniform random variable
51+ Transform function that takes uniform random variable.
5252 """
5353
5454 def transform (u ):
@@ -84,6 +84,8 @@ class SpikeAndSlabPrior:
8484
8585 Example
8686 -------
87+ >>> import pymc as pm
88+ >>> from causalpy.variable_selection_priors import SpikeAndSlabPrior
8789 >>> spike_slab = SpikeAndSlabPrior(dims="features")
8890 >>> with pm.Model():
8991 ... beta = spike_slab.create_variable("beta")
@@ -161,6 +163,8 @@ class HorseshoePrior:
161163
162164 Example
163165 -------
166+ >>> import pymc as pm
167+ >>> from causalpy.variable_selection_priors import HorseshoePrior
164168 >>> horseshoe = HorseshoePrior(dims="features")
165169 >>> with pm.Model():
166170 ... beta = horseshoe.create_variable("beta")
@@ -261,18 +265,11 @@ class VariableSelectionPrior:
261265 -------
262266 >>> import pymc as pm
263267 >>> from variable_selection_priors import VariableSelectionPrior
264- >>>
265268 >>> # Create spike-and-slab prior
266269 >>> vs_prior = VariableSelectionPrior("spike_and_slab")
267- >>>
268270 >>> with pm.Model() as model:
269271 ... # Create coefficients with variable selection
270- ... beta = vs_prior.create_prior(
271- ... name="beta",
272- ... n_params=5,
273- ... dims="features",
274- ... X=X_train, # For computing tau0 in horseshoe
275- ... )
272+ ... beta = vs_prior.create_prior(name="beta", n_params=5, dims="features")
276273 """
277274
278275 def __init__ (self , prior_type : str , hyperparams : Optional [Dict [str , Any ]] = None ):
@@ -375,11 +372,13 @@ def create_prior(
375372
376373 Example
377374 -------
378- >>> vs_prior = VariableSelectionPrior("horseshoe")
375+ >>> import pymc as pm
376+ >>> import pandas as pd
377+ >>> from variable_selection_priors import VariableSelectionPrior
378+ >>> X_train = pd.DataFrame{'x': [1, 2, 3, 4]}
379+ >>> vs_prior = VariableSelectionPrior("spike_and_slab")
379380 >>> with pm.Model() as model:
380- ... beta = vs_prior.create_prior(
381- ... "beta", n_params=10, dims="features", X=X_train
382- ... )
381+ ... beta = vs_prior.create_prior("beta", n_params=4, dims="features")
383382 """
384383 # Merge instance and call-specific hyperparameters
385384 default_hp = self ._get_default_hyperparams (n_params , X )
@@ -450,11 +449,6 @@ def get_inclusion_probabilities(
450449 ValueError
451450 If prior_type is not 'spike_and_slab' or gamma variables not found
452451
453- Example
454- -------
455- >>> result = vs_prior.get_inclusion_probabilities(idata, "beta")
456- >>> print(f"Selected features: {result['selected']}")
457- >>> print(f"Inclusion probs: {result['probabilities']}")
458452 """
459453 if self .prior_type != "spike_and_slab" :
460454 raise ValueError (
@@ -512,10 +506,6 @@ def get_shrinkage_factors(self, idata, param_name: str) -> Dict[str, np.ndarray]
512506 ValueError
513507 If prior_type is not 'horseshoe' or required variables not found
514508
515- Example
516- -------
517- >>> result = vs_prior.get_shrinkage_factors(idata, "beta")
518- >>> print(f"Shrinkage factors: {result['shrinkage_factors']}")
519509 """
520510 if self .prior_type != "horseshoe" :
521511 raise ValueError ("Shrinkage factors only available for 'horseshoe' priors" )
0 commit comments