diff --git a/Makefile b/Makefile index d109ae39..4fdfbf7c 100644 --- a/Makefile +++ b/Makefile @@ -13,10 +13,10 @@ check_lint: interrogate . doctest: - pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py + python -m pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py test: - pytest + python -m pytest uml: pyreverse -o png causalpy --output-directory docs/source/_static --ignore tests diff --git a/causalpy/__init__.py b/causalpy/__init__.py index 66031185..5587fb3e 100644 --- a/causalpy/__init__.py +++ b/causalpy/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import arviz as az import causalpy.pymc_models as pymc_models import causalpy.skl_models as skl_models @@ -28,8 +27,6 @@ from .experiments.regression_kink import RegressionKink from .experiments.synthetic_control import SyntheticControl -az.style.use("arviz-darkgrid") - __all__ = [ "__version__", "DifferenceInDifferences", diff --git a/causalpy/experiments/base.py b/causalpy/experiments/base.py index abc284c7..f24cc69a 100644 --- a/causalpy/experiments/base.py +++ b/causalpy/experiments/base.py @@ -17,6 +17,8 @@ from abc import abstractmethod +import arviz as az +import matplotlib.pyplot as plt import pandas as pd from sklearn.base import RegressorMixin @@ -63,12 +65,14 @@ def plot(self, *args, **kwargs) -> tuple: Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot` depending on the model type. """ - if isinstance(self.model, PyMCModel): - return self._bayesian_plot(*args, **kwargs) - elif isinstance(self.model, RegressorMixin): - return self._ols_plot(*args, **kwargs) - else: - raise ValueError("Unsupported model type") + # Apply arviz-darkgrid style only during plotting, then revert + with plt.style.context(az.style.library["arviz-darkgrid"]): + if isinstance(self.model, PyMCModel): + return self._bayesian_plot(*args, **kwargs) + elif isinstance(self.model, RegressorMixin): + return self._ols_plot(*args, **kwargs) + else: + raise ValueError("Unsupported model type") @abstractmethod def _bayesian_plot(self, *args, **kwargs): diff --git a/environment.yml b/environment.yml index a838de19..09850e05 100644 --- a/environment.yml +++ b/environment.yml @@ -16,3 +16,4 @@ dependencies: - statsmodels - xarray>=v2022.11.0 - pymc-extras>=0.3.0 + - python>=3.11