Skip to content

Commit 7663c08

Browse files
committed
mypy
1 parent acf9ef1 commit 7663c08

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

pymc_marketing/mmm/mmm.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -667,16 +667,16 @@ def _build_control_contribution(self) -> pt.TensorVariable | None:
667667
pt.TensorVariable | None
668668
Control contribution variable or None if no controls
669669
"""
670+
if self.control_columns is None or len(self.control_columns) == 0:
671+
return None
672+
670673
X_data = self.preprocessed_data["X"]
671-
has_controls = (
672-
self.control_columns is not None
673-
and len(self.control_columns) > 0
674-
and isinstance(X_data, pd.DataFrame)
675-
and all(column in X_data.columns for column in self.control_columns)
676-
)
674+
if not isinstance(X_data, pd.DataFrame):
675+
raise TypeError("X data must be a DataFrame for control contribution")
677676

678-
if not has_controls:
679-
return None
677+
if not all(column in X_data.columns for column in self.control_columns):
678+
missing_cols = set(self.control_columns) - set(X_data.columns)
679+
raise ValueError(f"Control columns {missing_cols} not found in X data")
680680

681681
if self.model_config["gamma_control"].dims != ("control",):
682682
self.model_config["gamma_control"].dims = "control"
@@ -858,10 +858,12 @@ def build_model( # type: ignore[override]
858858
if not isinstance(X_data, pd.DataFrame):
859859
raise TypeError("X data must be a DataFrame")
860860

861+
y_data = self.preprocessed_data["y"]
862+
if not isinstance(y_data, (pd.Series, np.ndarray)):
863+
raise TypeError("y data must be a Series or ndarray")
864+
861865
channel_data_scaled, target_scaled, _, target_scale_ = (
862-
self._create_scaled_data_variables(
863-
X_data[self.channel_columns], self.preprocessed_data["y"]
864-
)
866+
self._create_scaled_data_variables(X_data[self.channel_columns], y_data)
865867
)
866868

867869
# Create time index if needed
@@ -1152,8 +1154,13 @@ def _prepare_target_data(
11521154
Dictionary with target data ready for pm.set_data
11531155
"""
11541156
if y is None:
1155-
dtype = self.preprocessed_data["y"].dtype # type: ignore
1156-
y_data = np.zeros(n_rows, dtype=dtype)
1157+
# When y is None, create zeros array matching the type of preprocessed y
1158+
y_preprocessed = self.preprocessed_data["y"]
1159+
if isinstance(y_preprocessed, (pd.Series, np.ndarray)):
1160+
y_data = np.zeros(n_rows, dtype=np.asarray(y_preprocessed).dtype)
1161+
else:
1162+
# Default to float64 if type is unknown
1163+
y_data = np.zeros(n_rows, dtype="float64")
11571164
elif isinstance(y, pd.Series):
11581165
y_data = y.to_numpy()
11591166
elif isinstance(y, np.ndarray):

0 commit comments

Comments
 (0)