Skip to content

Commit 1ddf7f8

Browse files
committed
update docstrings
1 parent cdd9f77 commit 1ddf7f8

File tree

10 files changed

+242
-111
lines changed

10 files changed

+242
-111
lines changed

causalpy/data/datasets.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,22 @@ def _get_data_home() -> pathlib.Path:
4949

5050

5151
def load_data(dataset: str | None = None) -> pd.DataFrame:
52-
"""Loads the requested dataset and returns a pandas DataFrame.
52+
"""Load the requested dataset and return a pandas DataFrame.
5353
54-
:param dataset: The desired dataset to load
54+
Parameters
55+
----------
56+
dataset : str, optional
57+
The desired dataset to load. If None, raises ValueError.
58+
59+
Returns
60+
-------
61+
pd.DataFrame
62+
The loaded dataset as a pandas DataFrame.
63+
64+
Raises
65+
------
66+
ValueError
67+
If ``dataset`` is None or the requested dataset is not found.
5568
"""
5669

5770
if dataset in DATASETS:

causalpy/experiments/base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,14 @@ def idata(self) -> az.InferenceData:
6262
return self.model.idata
6363

6464
def print_coefficients(self, round_to: int | None = None) -> None:
65-
"""Ask the model to print its coefficients."""
65+
"""Ask the model to print its coefficients.
66+
67+
Parameters
68+
----------
69+
round_to : int, optional
70+
Number of significant figures to round to. Defaults to None,
71+
in which case 2 significant figures are used.
72+
"""
6673
self.model.print_coefficients(self.labels, round_to)
6774

6875
def plot(self, *args: Any, **kwargs: Any) -> tuple:

causalpy/experiments/diff_in_diff.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,24 @@ class DifferenceInDifferences(BaseExperiment):
4949
5050
.. note::
5151
52-
There is no pre/post intervention data distinction for DiD, we fit all the
53-
data available.
54-
:param data:
55-
A pandas dataframe
56-
:param formula:
57-
A statistical model formula
58-
:param time_variable_name:
59-
Name of the data column for the time variable
60-
:param group_variable_name:
61-
Name of the data column for the group variable
62-
:param post_treatment_variable_name:
63-
Name of the data column indicating post-treatment period (default: "post_treatment")
64-
:param model:
65-
A PyMC model for difference in differences
52+
There is no pre/post intervention data distinction for DiD, we fit
53+
all the data available.
54+
55+
Parameters
56+
----------
57+
data : pd.DataFrame
58+
A pandas dataframe.
59+
formula : str
60+
A statistical model formula.
61+
time_variable_name : str
62+
Name of the data column for the time variable.
63+
group_variable_name : str
64+
Name of the data column for the group variable.
65+
post_treatment_variable_name : str, optional
66+
Name of the data column indicating post-treatment period.
67+
Defaults to "post_treatment".
68+
model : PyMCModel or RegressorMixin, optional
69+
A PyMC or scikit-learn model for difference in differences. Defaults to None.
6670
6771
Examples
6872
--------

causalpy/experiments/instrumental_variable.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,30 @@
2727

2828

2929
class InstrumentalVariable(BaseExperiment):
30-
"""
31-
A class to analyse instrumental variable style experiments.
32-
33-
:param instruments_data: A pandas dataframe of instruments
34-
for our treatment variable. Should contain
35-
instruments Z, and treatment t
36-
:param data: A pandas dataframe of covariates for fitting
37-
the focal regression of interest. Should contain covariates X
38-
including treatment t and outcome y
39-
:param instruments_formula: A statistical model formula for
40-
the instrumental stage regression
41-
e.g. t ~ 1 + z1 + z2 + z3
42-
:param formula: A statistical model formula for the \n
43-
focal regression e.g. y ~ 1 + t + x1 + x2 + x3
44-
:param model: A PyMC model
45-
:param priors: An optional dictionary of priors for the
46-
mus and sigmas of both regressions. If priors are not
47-
specified we will substitute MLE estimates for the beta
48-
coefficients. Greater control can be achieved
49-
by specifying the priors directly e.g. priors = {
50-
"mus": [0, 0],
51-
"sigmas": [1, 1],
52-
"eta": 2,
53-
"lkj_sd": 2,
54-
}
30+
"""A class to analyse instrumental variable style experiments.
31+
32+
Parameters
33+
----------
34+
instruments_data : pd.DataFrame
35+
A pandas dataframe of instruments for our treatment variable.
36+
Should contain instruments Z, and treatment t.
37+
data : pd.DataFrame
38+
A pandas dataframe of covariates for fitting the focal regression
39+
of interest. Should contain covariates X including treatment t and
40+
outcome y.
41+
instruments_formula : str
42+
A statistical model formula for the instrumental stage regression,
43+
e.g. ``t ~ 1 + z1 + z2 + z3``.
44+
formula : str
45+
A statistical model formula for the focal regression,
46+
e.g. ``y ~ 1 + t + x1 + x2 + x3``.
47+
model : BaseExperiment, optional
48+
A PyMC model. Defaults to None.
49+
priors : dict, optional
50+
Dictionary of priors for the mus and sigmas of both regressions.
51+
If priors are not specified we will substitute MLE estimates for
52+
the beta coefficients. Example: ``priors = {"mus": [0, 0],
53+
"sigmas": [1, 1], "eta": 2, "lkj_sd": 2}``.
5554
5655
Examples
5756
--------

causalpy/experiments/inverse_propensity_weighting.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,23 @@
3131

3232

3333
class InversePropensityWeighting(BaseExperiment):
34-
"""
35-
A class to analyse inverse propensity weighting experiments.
34+
"""A class to analyse inverse propensity weighting experiments.
3635
37-
:param data:
38-
A pandas dataframe
39-
:param formula:
40-
A statistical model formula for the propensity model
41-
:param outcome_variable
42-
A string denoting the outcome variable in datq to be reweighted
43-
:param weighting_scheme:
44-
A string denoting which weighting scheme to use among: 'raw', 'robust',
45-
'doubly robust' or 'overlap'. See Aronow and Miller "Foundations
46-
of Agnostic Statistics" for discussion and computation of these
47-
weighting schemes.
48-
:param model:
49-
A PyMC model
36+
Parameters
37+
----------
38+
data : pd.DataFrame
39+
A pandas dataframe.
40+
formula : str
41+
A statistical model formula for the propensity model.
42+
outcome_variable : str
43+
A string denoting the outcome variable in data to be reweighted.
44+
weighting_scheme : str
45+
A string denoting which weighting scheme to use among: 'raw',
46+
'robust', 'doubly robust' or 'overlap'. See Aronow and Miller
47+
"Foundations of Agnostic Statistics" for discussion and computation
48+
of these weighting schemes.
49+
model : BaseExperiment, optional
50+
A PyMC model. Defaults to None.
5051
5152
Examples
5253
--------

causalpy/plot_utils.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,28 @@ def plot_xY(
3535
hdi_prob: float = 0.94,
3636
label: str | None = None,
3737
) -> Tuple[Line2D, PolyCollection]:
38-
"""
39-
Utility function to plot HDI intervals.
40-
41-
:param x:
42-
Pandas datetime index or numpy array of x-axis values
43-
:param y:
44-
Xarray data array of y-axis data
45-
:param ax:
46-
Matplotlib ax object
47-
:param plot_hdi_kwargs:
48-
Dictionary of keyword arguments passed to ax.plot()
49-
:param hdi_prob:
50-
The size of the HDI, default is 0.94
51-
:param label:
52-
The plot label
38+
"""Plot HDI intervals.
39+
40+
Parameters
41+
----------
42+
x : pd.DatetimeIndex, np.ndarray, pd.Index, pd.Series, or ExtensionArray
43+
Array-like of x-axis values, e.g. a pandas datetime index or numpy array.
44+
Y : xr.DataArray
45+
Xarray data array of y-axis data.
46+
ax : plt.Axes
47+
Matplotlib axes object.
48+
plot_hdi_kwargs : dict, optional
49+
Dictionary of keyword arguments passed to ax.plot().
50+
hdi_prob : float, optional
51+
The size of the HDI. Default is 0.94.
52+
label : str, optional
53+
The plot label.
54+
55+
Returns
56+
-------
57+
tuple
58+
Tuple of (Line2D, PolyCollection) handles for the plot line and
59+
HDI patch.
5360
"""
5461

5562
if plot_hdi_kwargs is None:
@@ -86,13 +93,20 @@ def get_hdi_to_df(
8693
x: xr.DataArray,
8794
hdi_prob: float = 0.94,
8895
) -> pd.DataFrame:
89-
"""
90-
Utility function to calculate and recover HDI intervals.
96+
"""Calculate and recover HDI intervals.
97+
98+
Parameters
99+
----------
100+
x : xr.DataArray
101+
Xarray data array.
102+
hdi_prob : float, optional
103+
The size of the HDI. Default is 0.94.
91104
92-
:param x:
93-
Xarray data array
94-
:param hdi_prob:
95-
The size of the HDI, default is 0.94
105+
Returns
106+
-------
107+
pd.DataFrame
108+
DataFrame containing the HDI intervals with 'lower' and 'higher'
109+
columns.
96110
"""
97111
hdi_result = az.hdi(x, hdi_prob=hdi_prob)
98112

causalpy/pymc_models.py

Lines changed: 84 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,15 @@ def __init__(
173173
priors: dict[str, Any] | None = None,
174174
) -> None:
175175
"""
176-
:param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
177-
:func:`pymc.sample` function. Defaults to an empty dictionary.
176+
Parameters
177+
----------
178+
sample_kwargs : dict, optional
179+
Dictionary of kwargs that get unpacked and passed to the
180+
:func:`pymc.sample` function. Defaults to an empty dictionary
181+
if None.
182+
priors : dict, optional
183+
Dictionary of priors for the model. Defaults to None, in which
184+
case default priors are used.
178185
"""
179186
super().__init__()
180187
self.idata = None
@@ -224,8 +231,23 @@ def _data_setter(self, X: xr.DataArray) -> None:
224231
def fit(
225232
self, X: xr.DataArray, y: xr.DataArray, coords: Dict[str, Any] | None = None
226233
) -> az.InferenceData:
227-
"""Draw samples from posterior, prior predictive, and posterior predictive
228-
distributions, placing them in the model's idata attribute.
234+
"""Draw samples from posterior, prior predictive, and posterior
235+
predictive distributions.
236+
237+
Parameters
238+
----------
239+
X : xr.DataArray
240+
Input features as an xarray DataArray.
241+
y : xr.DataArray
242+
Target variable as an xarray DataArray.
243+
coords : dict, optional
244+
Dictionary with coordinate names for named dimensions.
245+
Defaults to None.
246+
247+
Returns
248+
-------
249+
az.InferenceData
250+
InferenceData object containing the samples.
229251
"""
230252

231253
# Ensure random_seed is used in sample_prior_predictive() and
@@ -356,6 +378,16 @@ def calculate_cumulative_impact(self, impact: xr.DataArray) -> xr.DataArray:
356378
def print_coefficients(
357379
self, labels: list[str], round_to: int | None = None
358380
) -> None:
381+
"""Print the model coefficients with their labels.
382+
383+
Parameters
384+
----------
385+
labels : list of str
386+
List of strings representing the coefficient names.
387+
round_to : int, optional
388+
Number of significant figures to round to. Defaults to None,
389+
in which case 2 significant figures are used.
390+
"""
359391
if self.idata is None:
360392
raise RuntimeError("Model has not been fit")
361393

@@ -627,19 +659,27 @@ def build_model( # type: ignore
627659
coords: Dict[str, Any],
628660
priors: Dict[str, Any],
629661
) -> None:
630-
"""Specify model with treatment regression and focal regression data and priors
631-
632-
:param X: A pandas dataframe used to predict our outcome y
633-
:param Z: A pandas dataframe used to predict our treatment variable t
634-
:param y: An array of values representing our focal outcome y
635-
:param t: An array of values representing the treatment t of
636-
which we're interested in estimating the causal impact
637-
:param coords: A dictionary with the coordinate names for our
638-
instruments and covariates
639-
:param priors: An optional dictionary of priors for the mus and
640-
sigmas of both regressions
641-
:code:`priors = {"mus": [0, 0], "sigmas": [1, 1],
642-
"eta": 2, "lkj_sd": 2}`
662+
"""Specify model with treatment regression and focal regression
663+
data and priors.
664+
665+
Parameters
666+
----------
667+
X : np.ndarray
668+
Array used to predict our outcome y.
669+
Z : np.ndarray
670+
Array used to predict our treatment variable t.
671+
y : np.ndarray
672+
Array of values representing our focal outcome y.
673+
t : np.ndarray
674+
Array representing the treatment t of which we're interested
675+
in estimating the causal impact.
676+
coords : dict
677+
Dictionary with the coordinate names for our instruments and
678+
covariates.
679+
priors : dict
680+
Dictionary of priors for the mus and sigmas of both
681+
regressions. Example: ``priors = {"mus": [0, 0],
682+
"sigmas": [1, 1], "eta": 2, "lkj_sd": 2}``.
643683
"""
644684

645685
# --- Priors ---
@@ -725,13 +765,33 @@ def fit( # type: ignore
725765
priors: Dict[str, Any],
726766
ppc_sampler: str | None = None,
727767
) -> az.InferenceData:
728-
"""Draw samples from posterior distribution and potentially
729-
from the prior and posterior predictive distributions. The
730-
fit call can take values for the
731-
ppc_sampler = ['jax', 'pymc', None]
732-
We default to None, so the user can determine if they wish
733-
to spend time sampling the posterior predictive distribution
734-
independently.
768+
"""Draw samples from posterior distribution and potentially from
769+
the prior and posterior predictive distributions.
770+
771+
Parameters
772+
----------
773+
X : np.ndarray
774+
Array used to predict our outcome y.
775+
Z : np.ndarray
776+
Array used to predict our treatment variable t.
777+
y : np.ndarray
778+
Array of values representing our focal outcome y.
779+
t : np.ndarray
780+
Array representing the treatment variable.
781+
coords : dict
782+
Dictionary with coordinate names for named dimensions.
783+
priors : dict
784+
Dictionary of priors for the model.
785+
ppc_sampler : str, optional
786+
Sampler for posterior predictive distribution. Can be 'jax',
787+
'pymc', or None. Defaults to None, so the user can determine
788+
if they wish to spend time sampling the posterior predictive
789+
distribution independently.
790+
791+
Returns
792+
-------
793+
az.InferenceData
794+
InferenceData object containing the samples.
735795
"""
736796

737797
# Ensure random_seed is used in sample_prior_predictive() and

0 commit comments

Comments
 (0)