|
17 | 17 | import json
|
18 | 18 | import warnings
|
19 | 19 | from abc import ABC, abstractmethod
|
| 20 | +from functools import wraps |
20 | 21 | from inspect import signature
|
21 | 22 | from pathlib import Path
|
22 | 23 | from typing import Any
|
|
27 | 28 | import pymc as pm
|
28 | 29 | import xarray as xr
|
29 | 30 | from pymc.util import RandomState
|
| 31 | +from pymc_extras.printing import model_table |
| 32 | +from rich.table import Table |
30 | 33 |
|
31 | 34 | from pymc_marketing.hsgp_kwargs import HSGPKwargs
|
32 | 35 | from pymc_marketing.utils import from_netcdf
|
@@ -103,6 +106,20 @@ def accessor(self) -> xr.Dataset:
|
103 | 106 | return property(accessor)
|
104 | 107 |
|
105 | 108 |
|
| 109 | +def requires_model(func): |
| 110 | + """Ensure that the model is built before calling a method.""" |
| 111 | + |
| 112 | + @wraps(func) |
| 113 | + def wrapper(self, *args, **kwargs): |
| 114 | + if not hasattr(self, "model"): |
| 115 | + raise RuntimeError( |
| 116 | + "The model hasn't been built yet. Please call `build_model` first." |
| 117 | + ) |
| 118 | + return func(self, *args, **kwargs) |
| 119 | + |
| 120 | + return wrapper |
| 121 | + |
| 122 | + |
106 | 123 | def create_sample_kwargs(
|
107 | 124 | sampler_config: dict[str, Any],
|
108 | 125 | progressbar: bool | None,
|
@@ -1063,6 +1080,23 @@ def graphviz(self, **kwargs):
|
1063 | 1080 | """
|
1064 | 1081 | return pm.model_to_graphviz(self.model, **kwargs)
|
1065 | 1082 |
|
| 1083 | + @requires_model |
| 1084 | + def table(self, **model_table_kwargs) -> Table: |
| 1085 | + """Get the summary table of the model. |
| 1086 | +
|
| 1087 | + Parameters |
| 1088 | + ---------- |
| 1089 | + **model_table_kwargs |
| 1090 | + Keyword arguments for the `model_table` function |
| 1091 | +
|
| 1092 | + Returns |
| 1093 | + ------- |
| 1094 | + rich.table.Table |
| 1095 | + A rich table containing the summary of the model. |
| 1096 | +
|
| 1097 | + """ |
| 1098 | + return model_table(self.model, **model_table_kwargs) |
| 1099 | + |
1066 | 1100 | prior = create_idata_accessor(
|
1067 | 1101 | "prior",
|
1068 | 1102 | "The model hasn't been sampled yet, call .sample_prior_predictive() first",
|
|
0 commit comments