|
18 | 18 | broadcast_params, |
19 | 19 | normalize_size_param, |
20 | 20 | ) |
| 21 | +from pytensor.tensor.utils import faster_broadcast_to, faster_ndindex |
21 | 22 |
|
22 | 23 |
|
23 | 24 | # Scipy.stats is considerably slow to import |
@@ -976,19 +977,13 @@ def __call__(self, alphas, size=None, **kwargs): |
976 | 977 | @classmethod |
977 | 978 | def rng_fn(cls, rng, alphas, size): |
978 | 979 | if alphas.ndim > 1: |
979 | | - if size is None: |
980 | | - size = () |
981 | | - |
982 | | - size = tuple(np.atleast_1d(size)) |
983 | | - |
984 | | - if size: |
985 | | - alphas = np.broadcast_to(alphas, size + alphas.shape[-1:]) |
| 980 | + if size is not None: |
| 981 | + alphas = faster_broadcast_to(alphas, size + alphas.shape[-1:]) |
986 | 982 |
|
987 | 983 | samples_shape = alphas.shape |
988 | 984 | samples = np.empty(samples_shape) |
989 | | - for index in np.ndindex(*samples_shape[:-1]): |
| 985 | + for index in faster_ndindex(*samples_shape[:-1]): |
990 | 986 | samples[index] = rng.dirichlet(alphas[index]) |
991 | | - |
992 | 987 | return samples |
993 | 988 | else: |
994 | 989 | return rng.dirichlet(alphas, size=size) |
@@ -1800,11 +1795,11 @@ def rng_fn(cls, rng, n, p, size): |
1800 | 1795 | if size is None: |
1801 | 1796 | n, p = broadcast_params([n, p], [0, 1]) |
1802 | 1797 | else: |
1803 | | - n = np.broadcast_to(n, size) |
1804 | | - p = np.broadcast_to(p, size + p.shape[-1:]) |
| 1798 | + n = faster_broadcast_to(n, size) |
| 1799 | + p = faster_broadcast_to(p, size + p.shape[-1:]) |
1805 | 1800 |
|
1806 | 1801 | res = np.empty(p.shape, dtype=cls.dtype) |
1807 | | - for idx in np.ndindex(p.shape[:-1]): |
| 1802 | + for idx in faster_ndindex(p.shape[:-1]): |
1808 | 1803 | res[idx] = rng.multinomial(n[idx], p[idx]) |
1809 | 1804 | return res |
1810 | 1805 | else: |
@@ -1978,13 +1973,13 @@ def rng_fn(self, *params): |
1978 | 1973 | p.shape[:batch_ndim], |
1979 | 1974 | ) |
1980 | 1975 |
|
1981 | | - a = np.broadcast_to(a, size + a.shape[batch_ndim:]) |
| 1976 | + a = faster_broadcast_to(a, size + a.shape[batch_ndim:]) |
1982 | 1977 | if p is not None: |
1983 | | - p = np.broadcast_to(p, size + p.shape[batch_ndim:]) |
| 1978 | + p = faster_broadcast_to(p, size + p.shape[batch_ndim:]) |
1984 | 1979 |
|
1985 | 1980 | a_indexed_shape = a.shape[len(size) + 1 :] |
1986 | 1981 | out = np.empty(size + core_shape + a_indexed_shape, dtype=a.dtype) |
1987 | | - for idx in np.ndindex(size): |
| 1982 | + for idx in faster_ndindex(size): |
1988 | 1983 | out[idx] = rng.choice( |
1989 | 1984 | a[idx], p=None if p is None else p[idx], size=core_shape, replace=False |
1990 | 1985 | ) |
@@ -2097,10 +2092,10 @@ def rng_fn(self, rng, x, size): |
2097 | 2092 | if size is None: |
2098 | 2093 | size = x.shape[:batch_ndim] |
2099 | 2094 | else: |
2100 | | - x = np.broadcast_to(x, size + x.shape[batch_ndim:]) |
| 2095 | + x = faster_broadcast_to(x, size + x.shape[batch_ndim:]) |
2101 | 2096 |
|
2102 | 2097 | out = np.empty(size + x.shape[batch_ndim:], dtype=x.dtype) |
2103 | | - for idx in np.ndindex(size): |
| 2098 | + for idx in faster_ndindex(size): |
2104 | 2099 | out[idx] = rng.permutation(x[idx]) |
2105 | 2100 | return out |
2106 | 2101 |
|
|
0 commit comments