Skip to content

Commit d28164e

Browse files
committed
hide arviz style change in BaseExperiment.plot
1 parent ac95a60 commit d28164e

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

causalpy/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import arviz as az
1514

1615
import causalpy.pymc_models as pymc_models
1716
import causalpy.skl_models as skl_models
@@ -28,8 +27,6 @@
2827
from .experiments.regression_kink import RegressionKink
2928
from .experiments.synthetic_control import SyntheticControl
3029

31-
az.style.use("arviz-darkgrid")
32-
3330
__all__ = [
3431
"__version__",
3532
"DifferenceInDifferences",

causalpy/experiments/base.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,17 @@ def plot(self, *args, **kwargs) -> tuple:
6363
Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
6464
depending on the model type.
6565
"""
66-
if isinstance(self.model, PyMCModel):
67-
return self._bayesian_plot(*args, **kwargs)
68-
elif isinstance(self.model, RegressorMixin):
69-
return self._ols_plot(*args, **kwargs)
70-
else:
71-
raise ValueError("Unsupported model type")
66+
import arviz as az
67+
import matplotlib.pyplot as plt
68+
69+
# Apply arviz-darkgrid style only during plotting, then revert
70+
with plt.style.context(az.style.library["arviz-darkgrid"]):
71+
if isinstance(self.model, PyMCModel):
72+
return self._bayesian_plot(*args, **kwargs)
73+
elif isinstance(self.model, RegressorMixin):
74+
return self._ols_plot(*args, **kwargs)
75+
else:
76+
raise ValueError("Unsupported model type")
7277

7378
@abstractmethod
7479
def _bayesian_plot(self, *args, **kwargs):

0 commit comments

Comments
 (0)