Skip to content

Commit 61ad1ca

Browse files
committed
Remove deprecated model methods and reorganize remaining ones
1 parent 46d18b4 commit 61ad1ca

File tree

2 files changed

+47
-104
lines changed

2 files changed

+47
-104
lines changed

pymc/model.py

Lines changed: 39 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -688,13 +688,6 @@ def compile_d2logp(
688688
"""
689689
return self.model.compile_fn(self.d2logp(vars=vars, jacobian=jacobian))
690690

691-
def logpt(self, *args, **kwargs):
692-
warnings.warn(
693-
"Model.logpt has been deprecated. Use Model.logp instead.",
694-
FutureWarning,
695-
)
696-
return self.logp(*args, **kwargs)
697-
698691
def logp(
699692
self,
700693
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
@@ -769,13 +762,6 @@ def logp(
769762
logp_scalar.name = logp_scalar_name
770763
return logp_scalar
771764

772-
def dlogpt(self, *args, **kwargs):
773-
warnings.warn(
774-
"Model.dlogpt has been deprecated. Use Model.dlogp instead.",
775-
FutureWarning,
776-
)
777-
return self.dlogp(*args, **kwargs)
778-
779765
def dlogp(
780766
self,
781767
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
@@ -814,13 +800,6 @@ def dlogp(
814800
cost = self.logp(jacobian=jacobian)
815801
return gradient(cost, value_vars)
816802

817-
def d2logpt(self, *args, **kwargs):
818-
warnings.warn(
819-
"Model.d2logpt has been deprecated. Use Model.d2logp instead.",
820-
FutureWarning,
821-
)
822-
return self.d2logp(*args, **kwargs)
823-
824803
def d2logp(
825804
self,
826805
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
@@ -859,69 +838,29 @@ def d2logp(
859838
cost = self.logp(jacobian=jacobian)
860839
return hessian(cost, value_vars)
861840

862-
@property
863-
def datalogpt(self):
864-
warnings.warn(
865-
"Model.datalogpt has been deprecated. Use Model.datalogp instead.",
866-
FutureWarning,
867-
)
868-
return self.datalogp
869-
870841
@property
871842
def datalogp(self) -> Variable:
872843
"""Aesara scalar of log-probability of the observed variables and
873844
potential terms"""
874845
return self.observedlogp + self.potentiallogp
875846

876-
@property
877-
def varlogpt(self):
878-
warnings.warn(
879-
"Model.varlogpt has been deprecated. Use Model.varlogp instead.",
880-
FutureWarning,
881-
)
882-
return self.varlogp
883-
884847
@property
885848
def varlogp(self) -> Variable:
886849
"""Aesara scalar of log-probability of the unobserved random variables
887850
(excluding deterministic)."""
888851
return self.logp(vars=self.free_RVs)
889852

890-
@property
891-
def varlogp_nojact(self):
892-
warnings.warn(
893-
"Model.varlogp_nojact has been deprecated. Use Model.varlogp_nojac instead.",
894-
FutureWarning,
895-
)
896-
return self.varlogp_nojac
897-
898853
@property
899854
def varlogp_nojac(self) -> Variable:
900855
"""Aesara scalar of log-probability of the unobserved random variables
901856
(excluding deterministic) without jacobian term."""
902857
return self.logp(vars=self.free_RVs, jacobian=False)
903858

904-
@property
905-
def observedlogpt(self):
906-
warnings.warn(
907-
"Model.observedlogpt has been deprecated. Use Model.observedlogp instead.",
908-
FutureWarning,
909-
)
910-
return self.observedlogp
911-
912859
@property
913860
def observedlogp(self) -> Variable:
914861
"""Aesara scalar of log-probability of the observed variables"""
915862
return self.logp(vars=self.observed_RVs)
916863

917-
@property
918-
def potentiallogpt(self):
919-
warnings.warn(
920-
"Model.potentiallogpt has been deprecated. Use Model.potentiallogp instead.",
921-
FutureWarning,
922-
)
923-
return self.potentiallogp
924-
925864
@property
926865
def potentiallogp(self) -> Variable:
927866
"""Aesara scalar of log-probability of the Potential terms"""
@@ -933,14 +872,6 @@ def potentiallogp(self) -> Variable:
933872
else:
934873
return at.constant(0.0)
935874

936-
@property
937-
def vars(self):
938-
warnings.warn(
939-
"Model.vars has been deprecated. Use Model.value_vars instead.",
940-
FutureWarning,
941-
)
942-
return self.value_vars
943-
944875
@property
945876
def value_vars(self):
946877
"""List of unobserved random variables used as inputs to the model's
@@ -1013,6 +944,17 @@ def basic_RVs(self):
1013944
"""
1014945
return self.free_RVs + self.observed_RVs
1015946

947+
@property
948+
def unobserved_RVs(self):
949+
"""List of all random variables, including deterministic ones.
950+
951+
These are the actual random variable terms that make up the
952+
"sample-space" graph (i.e. you can sample these graphs by compiling them
953+
with `aesara.function`). If you want the corresponding log-likelihood terms,
954+
use `var.tag.value_var`.
955+
"""
956+
return self.free_RVs + self.deterministics
957+
1016958
@property
1017959
def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]:
1018960
"""Tuples of dimension names for specific model variables.
@@ -1318,6 +1260,34 @@ def set_data(
13181260

13191261
shared_object.set_value(values)
13201262

1263+
def initial_point(self, seed=None) -> Dict[str, np.ndarray]:
1264+
"""Computes the initial point of the model.
1265+
1266+
Returns
1267+
-------
1268+
ip : dict
1269+
Maps names of transformed variables to numeric initial values in the transformed space.
1270+
"""
1271+
fn = make_initial_point_fn(model=self, return_transformed=True)
1272+
return Point(fn(seed), model=self)
1273+
1274+
@property
1275+
def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]]:
1276+
"""Maps transformed variables to initial value placeholders.
1277+
1278+
Keys are the random variables (as returned by e.g. ``pm.Uniform()``) and
1279+
values are the numeric/symbolic initial values, strings denoting the strategy to get them, or None.
1280+
"""
1281+
return self._initial_values
1282+
1283+
def set_initval(self, rv_var, initval):
1284+
"""Sets an initial value (strategy) for a random variable."""
1285+
if initval is not None and not isinstance(initval, (Variable, str)):
1286+
# Convert scalars or array-like inputs to ndarrays
1287+
initval = rv_var.type.filter(initval)
1288+
1289+
self.initial_values[rv_var] = initval
1290+
13211291
def register_rv(
13221292
self, rv_var, name, data=None, total_size=None, dims=None, transform=UNSET, initval=None
13231293
):
@@ -1761,13 +1731,6 @@ def check_start_vals(self, start):
17611731
f"Initial evaluation results:\n{initial_eval}"
17621732
)
17631733

1764-
def check_test_point(self, *args, **kwargs):
1765-
warnings.warn(
1766-
"`Model.check_test_point` has been deprecated. Use `Model.point_logps` instead.",
1767-
FutureWarning,
1768-
)
1769-
return self.point_logps(*args, **kwargs)
1770-
17711734
def point_logps(self, point=None, round_vals=2):
17721735
"""Computes the log probability of `point` for all random variables in the model.
17731736

pymc/tests/test_model.py

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -482,15 +482,14 @@ def test_model_roundtrip(self):
482482
)
483483

484484

485-
def test_model_vars():
485+
def test_model_value_vars():
486486
with pm.Model() as model:
487487
a = pm.Normal("a")
488488
pm.Normal("x", a)
489489

490-
with pytest.warns(FutureWarning):
491-
old_vars = model.vars
492-
493-
assert old_vars == model.value_vars
490+
value_vars = model.value_vars
491+
assert len(value_vars) == 2
492+
assert set(value_vars) == set(pm.inputvars(model.logp()))
494493

495494

496495
def test_model_var_maps():
@@ -589,8 +588,7 @@ def test_point_logps():
589588
a = pm.Uniform("a")
590589
pm.Normal("x", a)
591590

592-
with pytest.warns(FutureWarning):
593-
logp_vals = model.check_test_point()
591+
logp_vals = model.point_logps()
594592

595593
assert "x" in logp_vals.keys()
596594
assert "a" in logp_vals.keys()
@@ -917,34 +915,16 @@ def test_set_data_constant_shape_error():
917915
pmodel.set_data("y", np.arange(10))
918916

919917

920-
def test_model_logpt_deprecation_warning():
918+
def test_model_deprecation_warning():
921919
with pm.Model() as m:
922920
x = pm.Normal("x", 0, 1, size=2)
923921
y = pm.LogNormal("y", 0, 1, size=2)
924922

925923
with pytest.warns(FutureWarning):
926-
m.logpt()
927-
928-
with pytest.warns(FutureWarning):
929-
m.dlogpt()
930-
931-
with pytest.warns(FutureWarning):
932-
m.d2logpt()
933-
934-
with pytest.warns(FutureWarning):
935-
m.datalogpt
936-
937-
with pytest.warns(FutureWarning):
938-
m.varlogpt
939-
940-
with pytest.warns(FutureWarning):
941-
m.observedlogpt
942-
943-
with pytest.warns(FutureWarning):
944-
m.potentiallogpt
924+
m.disc_vars
945925

946926
with pytest.warns(FutureWarning):
947-
m.varlogp_nojact
927+
m.cont_vars
948928

949929

950930
@pytest.mark.parametrize("jacobian", [True, False])

0 commit comments

Comments
 (0)