Skip to content

Commit e4c7143

Browse files
authored
Add table method to ModelBuilder to display rich table (#1786)
1 parent 18b2688 commit e4c7143

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

pymc_marketing/model_builder.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import json
1818
import warnings
1919
from abc import ABC, abstractmethod
20+
from functools import wraps
2021
from inspect import signature
2122
from pathlib import Path
2223
from typing import Any
@@ -27,6 +28,8 @@
2728
import pymc as pm
2829
import xarray as xr
2930
from pymc.util import RandomState
31+
from pymc_extras.printing import model_table
32+
from rich.table import Table
3033

3134
from pymc_marketing.hsgp_kwargs import HSGPKwargs
3235
from pymc_marketing.utils import from_netcdf
@@ -103,6 +106,20 @@ def accessor(self) -> xr.Dataset:
103106
return property(accessor)
104107

105108

109+
def requires_model(func):
110+
"""Ensure that the model is built before calling a method."""
111+
112+
@wraps(func)
113+
def wrapper(self, *args, **kwargs):
114+
if not hasattr(self, "model"):
115+
raise RuntimeError(
116+
"The model hasn't been built yet. Please call `build_model` first."
117+
)
118+
return func(self, *args, **kwargs)
119+
120+
return wrapper
121+
122+
106123
def create_sample_kwargs(
107124
sampler_config: dict[str, Any],
108125
progressbar: bool | None,
@@ -1063,6 +1080,23 @@ def graphviz(self, **kwargs):
10631080
"""
10641081
return pm.model_to_graphviz(self.model, **kwargs)
10651082

1083+
@requires_model
1084+
def table(self, **model_table_kwargs) -> Table:
1085+
"""Get the summary table of the model.
1086+
1087+
Parameters
1088+
----------
1089+
**model_table_kwargs
1090+
Keyword arguments for the `model_table` function
1091+
1092+
Returns
1093+
-------
1094+
rich.table.Table
1095+
A rich table containing the summary of the model.
1096+
1097+
"""
1098+
return model_table(self.model, **model_table_kwargs)
1099+
10661100
prior = create_idata_accessor(
10671101
"prior",
10681102
"The model hasn't been sampled yet, call .sample_prior_predictive() first",

tests/test_model_builder.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import pymc as pm
2525
import pytest
2626
import xarray as xr
27+
from rich.table import Table
2728

2829
from pymc_marketing.model_builder import (
2930
DifferentModelError,
@@ -794,3 +795,13 @@ def test_load_from_idata_errors(request, fixture_name, match) -> None:
794795
idata = request.getfixturevalue(fixture_name)
795796
with pytest.raises(DifferentModelError, match=match):
796797
ModelBuilderTest.load_from_idata(idata)
798+
799+
800+
def test_table() -> None:
801+
model = ModelBuilderTest()
802+
match = "The model hasn't been built yet"
803+
with pytest.raises(RuntimeError, match=match):
804+
model.table()
805+
806+
model.build_model(pd.DataFrame({"input": [1, 2, 3]}), pd.Series([1, 2, 3]))
807+
assert isinstance(model.table(), Table)

0 commit comments

Comments
 (0)