Skip to content

Commit 3d37a11

Browse files
committed
tests
1 parent 83ff985 commit 3d37a11

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pymc_marketing/mmm/mmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -857,7 +857,7 @@ def build_model( # type: ignore[override]
857857
raise TypeError("X data must be a DataFrame")
858858

859859
y_data = self.preprocessed_data["y"]
860-
if not isinstance(y_data, (pd.Series, np.ndarray)):
860+
if not isinstance(y_data, pd.Series | np.ndarray):
861861
raise TypeError("y data must be a Series or ndarray")
862862

863863
channel_data_scaled, target_scaled, _, target_scale_ = (
@@ -1154,7 +1154,7 @@ def _prepare_target_data(
11541154
if y is None:
11551155
# When y is None, create zeros array matching the type of preprocessed y
11561156
y_preprocessed = self.preprocessed_data["y"]
1157-
if isinstance(y_preprocessed, (pd.Series, np.ndarray)):
1157+
if isinstance(y_preprocessed, pd.Series | np.ndarray):
11581158
y_data = np.zeros(n_rows, dtype=np.asarray(y_preprocessed).dtype)
11591159
else:
11601160
# Default to float64 if type is unknown

tests/mmm/test_mmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2050,7 +2050,7 @@ def test_compute_scales(self, toy_X: pd.DataFrame, toy_y: pd.Series):
20502050
assert hasattr(mmm, "channel_scale")
20512051
assert hasattr(mmm, "target_scale")
20522052
assert len(mmm.channel_scale) == len(mmm.channel_columns)
2053-
assert isinstance(mmm.target_scale, (float, np.floating))
2053+
assert isinstance(mmm.target_scale, float | np.floating)
20542054

20552055
def test_build_intercept_static(self, toy_X: pd.DataFrame, toy_y: pd.Series):
20562056
"""Test _build_intercept with static intercept."""

0 commit comments

Comments
 (0)