diff --git a/pymc_extras/statespace/models/structural/components/autoregressive.py b/pymc_extras/statespace/models/structural/components/autoregressive.py index 5b0ee4e0..01074ccc 100644 --- a/pymc_extras/statespace/models/structural/components/autoregressive.py +++ b/pymc_extras/statespace/models/structural/components/autoregressive.py @@ -20,8 +20,8 @@ class AutoregressiveComponent(Component): name: str, default "auto_regressive" A name for this autoregressive component. Used to label dimensions and coordinates. - observed_state_names: list[str] | None, default None - List of strings for observed state labels. If None, defaults to ["data"]. + observed_state_names: Sequence[str] | NDArray | None, default None + Sequence of strings for observed state labels. If None, defaults to ["data"]. Notes ----- diff --git a/pymc_extras/statespace/models/structural/components/cycle.py b/pymc_extras/statespace/models/structural/components/cycle.py index 7c10d152..cf980c4d 100644 --- a/pymc_extras/statespace/models/structural/components/cycle.py +++ b/pymc_extras/statespace/models/structural/components/cycle.py @@ -39,7 +39,7 @@ class CycleComponent(Component): parameter, ``sigma_{name}`` will be added to the model. For multivariate time series, this is a vector (variable-specific innovation variances). - observed_state_names: list[str], optional + observed_state_names: Sequence[str] | NDArray, optional Names of the observed state variables. For univariate time series, defaults to ``["data"]``. For multivariate time series, specify a list of names for each endogenous variable. 
diff --git a/pymc_extras/statespace/models/structural/components/level_trend.py b/pymc_extras/statespace/models/structural/components/level_trend.py index ba44c706..a67ac3e3 100644 --- a/pymc_extras/statespace/models/structural/components/level_trend.py +++ b/pymc_extras/statespace/models/structural/components/level_trend.py @@ -25,8 +25,8 @@ class LevelTrendComponent(Component): name : str, default "level_trend" A name for this level-trend component. Used to label dimensions and coordinates. - observed_state_names : list[str] | None, default None - List of strings for observed state labels. If None, defaults to ["data"]. + observed_state_names : Sequence[str] | NDArray | None, default None + Sequence of strings for observed state labels. If None, defaults to ["data"]. Notes ----- diff --git a/pymc_extras/statespace/models/structural/components/measurement_error.py b/pymc_extras/statespace/models/structural/components/measurement_error.py index babac032..67dc0239 100644 --- a/pymc_extras/statespace/models/structural/components/measurement_error.py +++ b/pymc_extras/statespace/models/structural/components/measurement_error.py @@ -15,7 +15,7 @@ class MeasurementError(Component): ---------- name : str, optional Name of the measurement error component. Default is "MeasurementError". - observed_state_names : list[str] | None, optional + observed_state_names : Sequence[str] | NDArray | None, optional Names of the observed variables. If None, defaults to ["data"]. Notes diff --git a/pymc_extras/statespace/models/structural/components/regression.py b/pymc_extras/statespace/models/structural/components/regression.py index 89d26018..a5c57f56 100644 --- a/pymc_extras/statespace/models/structural/components/regression.py +++ b/pymc_extras/statespace/models/structural/components/regression.py @@ -24,8 +24,8 @@ class RegressionComponent(Component): k_exog. If None and k_exog is provided, coefficients will be named "{name}_1, {name}_2, ...". 
def _combine_property_2(self, other, name, allow_duplicates=True):
    """
    Combine a property from two components during component addition.

    Reads the attribute ``name`` from both ``self`` and ``other`` and merges the
    two values according to their types.

    Parameters
    ----------
    other : Component
        The other component whose property is being combined with this one.
    name : str
        The name of the property to combine (e.g., 'state_names', 'param_names').
    allow_duplicates : bool, default True
        Controls duplicate handling when the values are merged as sequences:
        - True: Concatenates the items directly, preserving duplicates
        - False: Adds only items from `other` that aren't already in `self`

    Returns
    -------
    Any
        Combined property value with type depending on the property type:
        - list on ``self``: a new list; a tuple or numpy array on ``other`` is
          flattened into it item by item
        - dict on ``self``: merged copy (``other`` overwrites ``self`` for same keys)
        - two equal numpy arrays: ``self``'s array, returned unchanged
        - otherwise, if either value is a numpy array: a list built from the
          items of both values (``None`` contributes nothing, a scalar
          contributes itself, a list/tuple/array contributes its items)
        - anything else: the single value if both are equal, or whichever one
          is not None

    Raises
    ------
    ValueError
        When neither value is a list, dict or numpy array and the two values are
        different and both non-None, so they cannot be automatically combined.
    """
    # Function-scope import so the method does not depend on a module-level
    # numpy import being present.
    import numpy as np

    def _as_items(value):
        # Flatten any supported sequence into a plain list; None means
        # "nothing to contribute" and a bare scalar becomes a one-item list.
        if isinstance(value, (list, tuple, np.ndarray)):
            return list(value)
        return [] if value is None else [value]

    def _merge(first, second):
        # `first` is always a list here; `second` is a list in every
        # supported case.
        if allow_duplicates:
            return first + second
        return first + [x for x in second if x not in first]

    self_prop = getattr(self, name)
    other_prop = getattr(other, name)

    if isinstance(self_prop, list):
        # `list + ndarray` is element-wise numpy addition and `list + tuple`
        # raises, so convert those to a list before concatenating.
        if isinstance(other_prop, (tuple, np.ndarray)):
            other_prop = list(other_prop)
        return _merge(self_prop, other_prop)

    if isinstance(self_prop, dict):
        merged = self_prop.copy()
        merged.update(other_prop)
        return merged

    self_is_array = isinstance(self_prop, np.ndarray)
    other_is_array = isinstance(other_prop, np.ndarray)

    if self_is_array and other_is_array and np.array_equal(self_prop, other_prop):
        return self_prop

    if self_is_array or other_is_array:
        # At least one side is an array: merge item-wise. Flattening the
        # non-array side avoids nesting a list/tuple as a single element.
        return _merge(_as_items(self_prop), _as_items(other_prop))

    if self_prop == other_prop:
        return self_prop
    if self_prop is None and other_prop is not None:
        return other_prop
    if self_prop is not None and other_prop is None:
        return self_prop

    # Different non-None values - this might indicate a problem
    raise ValueError(
        f"Cannot combine property '{name}': component values are different "
        f"({self_prop} vs {other_prop}) and cannot be automatically combined"
    )