Skip to content

Commit 290fcf4

Browse files
committed
.more stuff
1 parent 04bc0ac commit 290fcf4

File tree

6 files changed

+671
-156
lines changed

6 files changed

+671
-156
lines changed

docs/source/learn/core_notebooks/dims_module.ipynb

Lines changed: 628 additions & 149 deletions
Large diffs are not rendered by default.

pymc/backends/arviz.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,37 @@
4949

5050
_log = logging.getLogger(__name__)
5151

52+
53+
RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False
54+
55+
5256
# random variable object ...
5357
Var = Any
5458

59+
def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
60+
safe_coords = coords
61+
62+
if not RAISE_ON_INCOMPATIBLE_COORD_LENGTHS:
63+
coords_lengths = {k: len(v) for k, v in coords.items()}
64+
for var_name, var in vars_dict.items():
65+
# Iterate in reversed because of chain/draw batch dimensions
66+
for dim, dim_length in zip(reversed(dims[var_name]), reversed(var.shape)):
67+
coord_length = coords_lengths.get(dim, None)
68+
if (coord_length is not None) and (coord_length != dim_length):
69+
warnings.warn(
70+
f"Incompatible coordinate length of {coord_length} found for dimension {dim} of variable {var_name}.\n"
71+
"The originate coordinates for this dim will not be included in the returned dataset for any of the variables."
72+
"Instead they will default to `np.arange(var_length)`.\n"
73+
"To make this warning into an errror set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`",
74+
UserWarning,
75+
)
76+
if safe_coords is coords:
77+
safe_coords = coords.copy()
78+
safe_coords.pop(dim)
79+
coords_lengths.pop(dim)
80+
81+
# FIXME: Would be better to drop coordinates altogether, but arviz defaults to `np.arange(var_length)`
82+
return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs)
5583

5684
def find_observations(model: "Model") -> dict[str, Var]:
5785
"""If there are observations available, return them as a dictionary."""
@@ -365,7 +393,7 @@ def priors_to_xarray(self):
365393
priors_dict[group] = (
366394
None
367395
if var_names is None
368-
else dict_to_dataset(
396+
else dict_to_dataset_drop_incompatible_coords(
369397
{k: np.expand_dims(self.prior[k], 0) for k in var_names},
370398
library=pymc,
371399
coords=self.coords,

pymc/dims/transforms.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,13 @@ def extend_dim(array, dim):
6868
fill_val = norm - sum_vals / pt.sqrt(n)
6969

7070
out = ptx.concat([array, fill_val], dim=dim)
71+
import pytensor
72+
with pytensor.config.change_flags(optimizer_verbose=True):
73+
print(out.eval(mode="minimum_compile").shape)
7174
return out - norm
7275

7376
@staticmethod
74-
def reduct_dim(array, dim):
77+
def reduce_dim(array, dim):
7578
n = array.sizes[dim].astype("floatX")
7679
last = array.isel({dim: -1})
7780

@@ -81,7 +84,7 @@ def reduct_dim(array, dim):
8184

8285
def forward(self, value, *rv_inputs):
8386
for dim in self.dims:
84-
value = self.reduct_dim(value, dim=dim)
87+
value = self.reduce_dim(value, dim=dim)
8588
return value
8689

8790
def backward(self, value, *rv_inputs):
@@ -92,4 +95,4 @@ def backward(self, value, *rv_inputs):
9295
def log_jac_det(self, value, *rv_inputs):
9396
# Use following once broadcast_like is implemented
9497
# as_xtensor(0).broadcast_like(value, exclude=self.dims)`
95-
return (value * 0).sum(self.dims)
98+
return value.sum(self.dims) * 0

pymc/model/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1781,7 +1781,7 @@ def point_logps(self, point=None, round_vals=2, **kwargs):
17811781
point = self.initial_point()
17821782

17831783
factors = self.basic_RVs + self.potentials
1784-
factor_logps_fn = [pt.sum(factor) for factor in self.logp(factors, sum=False)]
1784+
factor_logps_fn = [factor.sum() for factor in self.logp(factors, sum=False)]
17851785
return {
17861786
factor.name: np.round(np.asarray(factor_logp), round_vals)
17871787
for factor, factor_logp in zip(

pymc/model_graph.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,11 @@ def create_plate_label_with_dim_length(
7777
)
7878

7979

80+
from pymc.pytensorf import _cheap_eval_mode
81+
82+
8083
def fast_eval(var):
81-
return function([], var, mode="FAST_COMPILE")()
84+
return function([], var, mode=_cheap_eval_mode)()
8285

8386

8487
class NodeType(str, Enum):

tests/dims/test_distributions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def test_zerosumnormal():
6060
np.testing.assert_allclose(ip_value, ref_ip_value)
6161

6262
logp_fn = model.compile_logp()
63-
ref_logp_fn = ref_model.compile_logp()
6463
logp_fn(ip)
64+
ref_logp_fn = ref_model.compile_logp()
6565
# np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ref_ip))
6666
# Test a new
67+
68+
test_zerosumnormal()

0 commit comments

Comments
 (0)