Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ check_lint:
interrogate .

doctest:
pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
python -m pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py

test:
pytest
python -m pytest

uml:
pyreverse -o png causalpy --output-directory docs/source/_static --ignore tests
Expand Down
3 changes: 0 additions & 3 deletions causalpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import arviz as az

import causalpy.pymc_models as pymc_models
import causalpy.skl_models as skl_models
Expand All @@ -28,8 +27,6 @@
from .experiments.regression_kink import RegressionKink
from .experiments.synthetic_control import SyntheticControl

az.style.use("arviz-darkgrid")

__all__ = [
"__version__",
"DifferenceInDifferences",
Expand Down
17 changes: 11 additions & 6 deletions causalpy/experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,17 @@ def plot(self, *args, **kwargs) -> tuple:
Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
depending on the model type.
"""
if isinstance(self.model, PyMCModel):
return self._bayesian_plot(*args, **kwargs)
elif isinstance(self.model, RegressorMixin):
return self._ols_plot(*args, **kwargs)
else:
raise ValueError("Unsupported model type")
import arviz as az
import matplotlib.pyplot as plt

# Apply arviz-darkgrid style only during plotting, then revert
with plt.style.context(az.style.library["arviz-darkgrid"]):
if isinstance(self.model, PyMCModel):
return self._bayesian_plot(*args, **kwargs)
elif isinstance(self.model, RegressorMixin):
return self._ols_plot(*args, **kwargs)
else:
raise ValueError("Unsupported model type")

@abstractmethod
def _bayesian_plot(self, *args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ dependencies:
- statsmodels
- xarray>=v2022.11.0
- pymc-extras>=0.3.0
- python>=3.11