|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -import collections |
16 | 15 | import functools
|
17 | 16 | import threading
|
18 | 17 | import types
|
|
65 | 64 | from pymc.distributions.transforms import _default_transform
|
66 | 65 | from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning
|
67 | 66 | from pymc.initial_point import make_initial_point_fn
|
68 |
| -from pymc.math import flatten_list |
69 | 67 | from pymc.util import (
|
70 | 68 | UNSET,
|
71 | 69 | WithMemoization,
|
|
86 | 84 | "compile_fn",
|
87 | 85 | ]
|
88 | 86 |
|
89 |
| -FlatView = collections.namedtuple("FlatView", "input, replacements") |
90 |
| - |
91 | 87 |
|
92 | 88 | class InstanceMethod:
|
93 | 89 | """Class for hiding references to instance methods so they can be pickled.
|
@@ -1663,49 +1659,6 @@ def profile(self, outs, *, n=1000, point=None, profile=True, **kwargs):
|
1663 | 1659 |
|
1664 | 1660 | return f.profile
|
1665 | 1661 |
|
1666 |
| - def flatten(self, vars=None, order=None, inputvar=None): |
1667 |
| - """Flattens model's input and returns: |
1668 |
| -
|
1669 |
| - Parameters |
1670 |
| - ---------- |
1671 |
| - vars: list of variables or None |
1672 |
| - if None, then all model.free_RVs are used for flattening input |
1673 |
| - order: list of variable names |
1674 |
| - Optional, use predefined ordering |
1675 |
| - inputvar: at.vector |
1676 |
| - Optional, use predefined inputvar |
1677 |
| -
|
1678 |
| - Returns |
1679 |
| - ------- |
1680 |
| - flat_view |
1681 |
| - """ |
1682 |
| - if vars is None: |
1683 |
| - vars = self.value_vars |
1684 |
| - if order is not None: |
1685 |
| - var_map = {v.name: v for v in vars} |
1686 |
| - vars = [var_map[n] for n in order] |
1687 |
| - |
1688 |
| - if inputvar is None: |
1689 |
| - inputvar = at.vector("flat_view", dtype=aesara.config.floatX) |
1690 |
| - if aesara.config.compute_test_value != "off": |
1691 |
| - if vars: |
1692 |
| - inputvar.tag.test_value = flatten_list(vars).tag.test_value |
1693 |
| - else: |
1694 |
| - inputvar.tag.test_value = np.asarray([], inputvar.dtype) |
1695 |
| - |
1696 |
| - replacements = {} |
1697 |
| - last_idx = 0 |
1698 |
| - for var in vars: |
1699 |
| - arr_len = at.prod(var.shape, dtype="int64") |
1700 |
| - replacements[self.named_vars[var.name]] = ( |
1701 |
| - inputvar[last_idx : (last_idx + arr_len)].reshape(var.shape).astype(var.dtype) |
1702 |
| - ) |
1703 |
| - last_idx += arr_len |
1704 |
| - |
1705 |
| - flat_view = FlatView(inputvar, replacements) |
1706 |
| - |
1707 |
| - return flat_view |
1708 |
| - |
1709 | 1662 | def update_start_vals(self, a: Dict[str, np.ndarray], b: Dict[str, np.ndarray]):
|
1710 | 1663 | r"""Update point `a` with `b`, without overwriting existing keys.
|
1711 | 1664 |
|
|
0 commit comments