Skip to content

Commit db0522e

Browse files
committed
trying to fix doctests
Signed-off-by: Nathaniel <[email protected]>
1 parent db1e81d commit db0522e

File tree

1 file changed

+14
-24
lines changed

1 file changed

+14
-24
lines changed

causalpy/variable_selection_priors.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@ def _relaxed_bernoulli_transform(
4141
Parameters
4242
----------
4343
p : float or PyMC variable
44-
Probability parameter
44+
Probability parameter.
4545
temperature : float, default=0.1
46-
Temperature parameter (lower = more binary)
46+
Temperature parameter (lower = more binary).
4747
4848
Returns
4949
-------
5050
function
51-
Transform function that takes uniform random variable
51+
Transform function that takes uniform random variable.
5252
"""
5353

5454
def transform(u):
@@ -84,6 +84,8 @@ class SpikeAndSlabPrior:
8484
8585
Example
8686
-------
87+
>>> import pymc as pm
88+
>>> from causalpy.variable_selection_priors import SpikeAndSlabPrior
8789
>>> spike_slab = SpikeAndSlabPrior(dims="features")
8890
>>> with pm.Model():
8991
... beta = spike_slab.create_variable("beta")
@@ -161,6 +163,8 @@ class HorseshoePrior:
161163
162164
Example
163165
-------
166+
>>> import pymc as pm
167+
>>> from causalpy.variable_selection_priors import HorseshoePrior
164168
>>> horseshoe = HorseshoePrior(dims="features")
165169
>>> with pm.Model():
166170
... beta = horseshoe.create_variable("beta")
@@ -261,18 +265,11 @@ class VariableSelectionPrior:
261265
-------
262266
>>> import pymc as pm
263267
>>> from variable_selection_priors import VariableSelectionPrior
264-
>>>
265268
>>> # Create spike-and-slab prior
266269
>>> vs_prior = VariableSelectionPrior("spike_and_slab")
267-
>>>
268270
>>> with pm.Model() as model:
269271
... # Create coefficients with variable selection
270-
... beta = vs_prior.create_prior(
271-
... name="beta",
272-
... n_params=5,
273-
... dims="features",
274-
... X=X_train, # For computing tau0 in horseshoe
275-
... )
272+
... beta = vs_prior.create_prior(name="beta", n_params=5, dims="features")
276273
"""
277274

278275
def __init__(self, prior_type: str, hyperparams: Optional[Dict[str, Any]] = None):
@@ -375,11 +372,13 @@ def create_prior(
375372
376373
Example
377374
-------
378-
>>> vs_prior = VariableSelectionPrior("horseshoe")
375+
>>> import pymc as pm
376+
>>> import pandas as pd
377+
>>> from variable_selection_priors import VariableSelectionPrior
378+
>>> X_train = pd.DataFrame{'x': [1, 2, 3, 4]}
379+
>>> vs_prior = VariableSelectionPrior("spike_and_slab")
379380
>>> with pm.Model() as model:
380-
... beta = vs_prior.create_prior(
381-
... "beta", n_params=10, dims="features", X=X_train
382-
... )
381+
... beta = vs_prior.create_prior("beta", n_params=4, dims="features")
383382
"""
384383
# Merge instance and call-specific hyperparameters
385384
default_hp = self._get_default_hyperparams(n_params, X)
@@ -450,11 +449,6 @@ def get_inclusion_probabilities(
450449
ValueError
451450
If prior_type is not 'spike_and_slab' or gamma variables not found
452451
453-
Example
454-
-------
455-
>>> result = vs_prior.get_inclusion_probabilities(idata, "beta")
456-
>>> print(f"Selected features: {result['selected']}")
457-
>>> print(f"Inclusion probs: {result['probabilities']}")
458452
"""
459453
if self.prior_type != "spike_and_slab":
460454
raise ValueError(
@@ -512,10 +506,6 @@ def get_shrinkage_factors(self, idata, param_name: str) -> Dict[str, np.ndarray]
512506
ValueError
513507
If prior_type is not 'horseshoe' or required variables not found
514508
515-
Example
516-
-------
517-
>>> result = vs_prior.get_shrinkage_factors(idata, "beta")
518-
>>> print(f"Shrinkage factors: {result['shrinkage_factors']}")
519509
"""
520510
if self.prior_type != "horseshoe":
521511
raise ValueError("Shrinkage factors only available for 'horseshoe' priors")

0 commit comments

Comments
 (0)