Skip to content

Commit 5be18bb

Browse files
committed
add description of round_to kwarg to docstrings
1 parent dcb913f commit 5be18bb

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

causalpy/pymc_experiments.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@ def _input_validation(self, data, treatment_time):
234234
def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
235235
"""
236236
Plot the results
237+
238+
:param round_to:
239+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
237240
"""
238241
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
239242

@@ -419,7 +422,11 @@ class SyntheticControl(PrePostFit):
419422
expt_type = "Synthetic Control"
420423

421424
def plot(self, plot_predictors=False, **kwargs):
422-
"""Plot the results"""
425+
"""Plot the results
426+
427+
:param round_to:
428+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
429+
"""
423430
fig, ax = super().plot(counterfactual_label="Synthetic control", **kwargs)
424431
if plot_predictors:
425432
# plot control units as well
@@ -585,7 +592,9 @@ def _input_validation(self):
585592

586593
def plot(self, round_to=None):
587594
"""Plot the results.
588-
Creating the combined mean + HDI legend entries is a bit involved.
595+
596+
:param round_to:
597+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
589598
"""
590599
fig, ax = plt.subplots()
591600

@@ -902,6 +911,9 @@ def _is_treated(self, x):
902911
def plot(self, round_to=None):
903912
"""
904913
Plot the results
914+
915+
:param round_to:
916+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
905917
"""
906918
fig, ax = plt.subplots()
907919
# Plot raw data
@@ -1116,6 +1128,9 @@ def _is_treated(self, x):
11161128
def plot(self, round_to=None):
11171129
"""
11181130
Plot the results
1131+
1132+
:param round_to:
1133+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
11191134
"""
11201135
fig, ax = plt.subplots()
11211136
# Plot raw data
@@ -1305,7 +1320,11 @@ def _input_validation(self) -> None:
13051320
)
13061321

13071322
def plot(self, round_to=None):
1308-
"""Plot the results"""
1323+
"""Plot the results
1324+
1325+
:param round_to:
1326+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
1327+
"""
13091328
fig, ax = plt.subplots(
13101329
2, 1, figsize=(7, 9), gridspec_kw={"height_ratios": [3, 1]}
13111330
)

causalpy/skl_experiments.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,11 @@ def __init__(
116116
self.post_impact_cumulative = np.cumsum(self.post_impact)
117117

118118
def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
119-
"""Plot experiment results"""
119+
"""Plot experiment results
120+
121+
:param round_to:
122+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
123+
"""
120124
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
121125

122126
ax[0].plot(self.datapre.index, self.pre_y, "k.")
@@ -263,7 +267,11 @@ class SyntheticControl(PrePostFit):
263267
"""
264268

265269
def plot(self, plot_predictors=False, round_to=None, **kwargs):
266-
"""Plot the results"""
270+
"""Plot the results
271+
272+
:param round_to:
273+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
274+
"""
267275
fig, ax = super().plot(
268276
counterfactual_label="Synthetic control", round_to=round_to, **kwargs
269277
)
@@ -404,7 +412,11 @@ def __init__(
404412
self.causal_impact = self.y_pred_treatment[1] - self.y_pred_counterfactual[0]
405413

406414
def plot(self, round_to=None):
407-
"""Plot results"""
415+
"""Plot results
416+
417+
:param round_to:
418+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
419+
"""
408420
fig, ax = plt.subplots()
409421

410422
# Plot raw data
@@ -614,7 +626,11 @@ def _is_treated(self, x):
614626
return np.greater_equal(x, self.treatment_threshold)
615627

616628
def plot(self, round_to=None):
617-
"""Plot results"""
629+
"""Plot results
630+
631+
:param round_to:
632+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
633+
"""
618634
fig, ax = plt.subplots()
619635
# Plot raw data
620636
sns.scatterplot(

0 commit comments

Comments
 (0)