Skip to content

Commit 4511792

Browse files
committed
Remove unused Model.flatten
1 parent 69719ee commit 4511792

File tree

1 file changed

+0
-47
lines changed

1 file changed

+0
-47
lines changed

pymc/model.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import collections
1615
import functools
1716
import threading
1817
import types
@@ -65,7 +64,6 @@
6564
from pymc.distributions.transforms import _default_transform
6665
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning
6766
from pymc.initial_point import make_initial_point_fn
68-
from pymc.math import flatten_list
6967
from pymc.util import (
7068
UNSET,
7169
WithMemoization,
@@ -86,8 +84,6 @@
8684
"compile_fn",
8785
]
8886

89-
FlatView = collections.namedtuple("FlatView", "input, replacements")
90-
9187

9288
class InstanceMethod:
9389
"""Class for hiding references to instance methods so they can be pickled.
@@ -1663,49 +1659,6 @@ def profile(self, outs, *, n=1000, point=None, profile=True, **kwargs):
16631659

16641660
return f.profile
16651661

1666-
def flatten(self, vars=None, order=None, inputvar=None):
1667-
"""Flattens model's input and returns:
1668-
1669-
Parameters
1670-
----------
1671-
vars: list of variables or None
1672-
if None, then all model.free_RVs are used for flattening input
1673-
order: list of variable names
1674-
Optional, use predefined ordering
1675-
inputvar: at.vector
1676-
Optional, use predefined inputvar
1677-
1678-
Returns
1679-
-------
1680-
flat_view
1681-
"""
1682-
if vars is None:
1683-
vars = self.value_vars
1684-
if order is not None:
1685-
var_map = {v.name: v for v in vars}
1686-
vars = [var_map[n] for n in order]
1687-
1688-
if inputvar is None:
1689-
inputvar = at.vector("flat_view", dtype=aesara.config.floatX)
1690-
if aesara.config.compute_test_value != "off":
1691-
if vars:
1692-
inputvar.tag.test_value = flatten_list(vars).tag.test_value
1693-
else:
1694-
inputvar.tag.test_value = np.asarray([], inputvar.dtype)
1695-
1696-
replacements = {}
1697-
last_idx = 0
1698-
for var in vars:
1699-
arr_len = at.prod(var.shape, dtype="int64")
1700-
replacements[self.named_vars[var.name]] = (
1701-
inputvar[last_idx : (last_idx + arr_len)].reshape(var.shape).astype(var.dtype)
1702-
)
1703-
last_idx += arr_len
1704-
1705-
flat_view = FlatView(inputvar, replacements)
1706-
1707-
return flat_view
1708-
17091662
def update_start_vals(self, a: Dict[str, np.ndarray], b: Dict[str, np.ndarray]):
17101663
r"""Update point `a` with `b`, without overwriting existing keys.
17111664

0 commit comments

Comments
 (0)