|
50 | 50 |
|
51 | 51 | _log = logging.getLogger(__name__)
|
52 | 52 |
|
| 53 | + |
| 54 | +RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False |
| 55 | + |
| 56 | + |
53 | 57 | # random variable object ...
|
54 | 58 | Var = Any
|
55 | 59 |
|
56 | 60 |
|
| 61 | +def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs): |
| 62 | + safe_coords = coords |
| 63 | + |
| 64 | + if not RAISE_ON_INCOMPATIBLE_COORD_LENGTHS: |
| 65 | + coords_lengths = {k: len(v) for k, v in coords.items()} |
| 66 | + for var_name, var in vars_dict.items(): |
| 67 | + # Iterate in reversed because of chain/draw batch dimensions |
| 68 | + for dim, dim_length in zip(reversed(dims.get(var_name, ())), reversed(var.shape)): |
| 69 | + coord_length = coords_lengths.get(dim, None) |
| 70 | + if (coord_length is not None) and (coord_length != dim_length): |
| 71 | + warnings.warn( |
| 72 | + f"Incompatible coordinate length of {coord_length} for dimension '{dim}' of variable '{var_name}'.\n" |
| 73 | + "This usually happens when a sliced or concatenated variable is wrapped as a `pymc.dims.Deterministic`." |
| 74 | + "The originate coordinates for this dim will not be included in the returned dataset for any of the variables. " |
| 75 | + "Instead they will default to `np.arange(var_length)` and the shorter variables will be right-padded with nan.\n" |
| 76 | + "To make this warning into an error set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`", |
| 77 | + UserWarning, |
| 78 | + ) |
| 79 | + if safe_coords is coords: |
| 80 | + safe_coords = coords.copy() |
| 81 | + safe_coords.pop(dim) |
| 82 | + coords_lengths.pop(dim) |
| 83 | + |
| 84 | + # FIXME: Would be better to drop coordinates altogether, but arviz defaults to `np.arange(var_length)` |
| 85 | + return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs) |
| 86 | + |
| 87 | + |
57 | 88 | def find_observations(model: "Model") -> dict[str, Var]:
|
58 | 89 | """If there are observations available, return them as a dictionary."""
|
59 | 90 | observations = {}
|
@@ -366,7 +397,7 @@ def priors_to_xarray(self):
|
366 | 397 | priors_dict[group] = (
|
367 | 398 | None
|
368 | 399 | if var_names is None
|
369 |
| - else dict_to_dataset( |
| 400 | + else dict_to_dataset_drop_incompatible_coords( |
370 | 401 | {k: np.expand_dims(self.prior[k], 0) for k in var_names},
|
371 | 402 | library=pymc,
|
372 | 403 | coords=self.coords,
|
|
0 commit comments