@@ -173,8 +173,15 @@ def __init__(
173173 priors : dict [str , Any ] | None = None ,
174174 ) -> None :
175175 """
176- :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
177- :func:`pymc.sample` function. Defaults to an empty dictionary.
176+ Parameters
177+ ----------
178+ sample_kwargs : dict, optional
179+ Dictionary of kwargs that get unpacked and passed to the
180+ :func:`pymc.sample` function. Defaults to an empty dictionary
181+ if None.
182+ priors : dict, optional
183+ Dictionary of priors for the model. Defaults to None, in which
184+ case default priors are used.
178185 """
179186 super ().__init__ ()
180187 self .idata = None
@@ -224,8 +231,23 @@ def _data_setter(self, X: xr.DataArray) -> None:
224231 def fit (
225232 self , X : xr .DataArray , y : xr .DataArray , coords : Dict [str , Any ] | None = None
226233 ) -> az .InferenceData :
227- """Draw samples from posterior, prior predictive, and posterior predictive
228- distributions, placing them in the model's idata attribute.
234+ """Draw samples from posterior, prior predictive, and posterior
235+ predictive distributions.
236+
237+ Parameters
238+ ----------
239+ X : xr.DataArray
240+ Input features as an xarray DataArray.
241+ y : xr.DataArray
242+ Target variable as an xarray DataArray.
243+ coords : dict, optional
244+ Dictionary with coordinate names for named dimensions.
245+ Defaults to None.
246+
247+ Returns
248+ -------
249+ az.InferenceData
250+ InferenceData object containing the samples.
229251 """
230252
231253 # Ensure random_seed is used in sample_prior_predictive() and
@@ -356,6 +378,16 @@ def calculate_cumulative_impact(self, impact: xr.DataArray) -> xr.DataArray:
356378 def print_coefficients (
357379 self , labels : list [str ], round_to : int | None = None
358380 ) -> None :
381+ """Print the model coefficients with their labels.
382+
383+ Parameters
384+ ----------
385+ labels : list of str
386+ List of strings representing the coefficient names.
387+ round_to : int, optional
388+ Number of significant figures to round to. Defaults to None,
389+ in which case 2 significant figures are used.
390+ """
359391 if self .idata is None :
360392 raise RuntimeError ("Model has not been fit" )
361393
@@ -627,19 +659,27 @@ def build_model( # type: ignore
627659 coords : Dict [str , Any ],
628660 priors : Dict [str , Any ],
629661 ) -> None :
630- """Specify model with treatment regression and focal regression data and priors
631-
632- :param X: A pandas dataframe used to predict our outcome y
633- :param Z: A pandas dataframe used to predict our treatment variable t
634- :param y: An array of values representing our focal outcome y
635- :param t: An array of values representing the treatment t of
636- which we're interested in estimating the causal impact
637- :param coords: A dictionary with the coordinate names for our
638- instruments and covariates
639- :param priors: An optional dictionary of priors for the mus and
640- sigmas of both regressions
641- :code:`priors = {"mus": [0, 0], "sigmas": [1, 1],
642- "eta": 2, "lkj_sd": 2}`
662+ """Specify model with treatment regression and focal regression
663+ data and priors.
664+
665+ Parameters
666+ ----------
667+ X : np.ndarray
668+ Array used to predict our outcome y.
669+ Z : np.ndarray
670+ Array used to predict our treatment variable t.
671+ y : np.ndarray
672+ Array of values representing our focal outcome y.
673+ t : np.ndarray
674+ Array representing the treatment t of which we're interested
675+ in estimating the causal impact.
676+ coords : dict
677+ Dictionary with the coordinate names for our instruments and
678+ covariates.
679+ priors : dict
680+ Dictionary of priors for the mus and sigmas of both
681+ regressions. Example: ``priors = {"mus": [0, 0],
682+ "sigmas": [1, 1], "eta": 2, "lkj_sd": 2}``.
643683 """
644684
645685 # --- Priors ---
@@ -725,13 +765,33 @@ def fit( # type: ignore
725765 priors : Dict [str , Any ],
726766 ppc_sampler : str | None = None ,
727767 ) -> az .InferenceData :
728- """Draw samples from posterior distribution and potentially
729- from the prior and posterior predictive distributions. The
730- fit call can take values for the
731- ppc_sampler = ['jax', 'pymc', None]
732- We default to None, so the user can determine if they wish
733- to spend time sampling the posterior predictive distribution
734- independently.
768+ """Draw samples from posterior distribution and potentially from
769+ the prior and posterior predictive distributions.
770+
771+ Parameters
772+ ----------
773+ X : np.ndarray
774+ Array used to predict our outcome y.
775+ Z : np.ndarray
776+ Array used to predict our treatment variable t.
777+ y : np.ndarray
778+ Array of values representing our focal outcome y.
779+ t : np.ndarray
780+ Array representing the treatment variable.
781+ coords : dict
782+ Dictionary with coordinate names for named dimensions.
783+ priors : dict
784+ Dictionary of priors for the model.
785+ ppc_sampler : str, optional
786+ Sampler for posterior predictive distribution. Can be 'jax',
787+ 'pymc', or None. Defaults to None, so the user can determine
788+ if they wish to spend time sampling the posterior predictive
789+ distribution independently.
790+
791+ Returns
792+ -------
793+ az.InferenceData
794+ InferenceData object containing the samples.
735795 """
736796
737797 # Ensure random_seed is used in sample_prior_predictive() and
0 commit comments