Skip to content

Commit b67196e

Browse files
authored
Handle 0-dim channel_xr in create_zero_dataset (#1890)
* Handle 0-dim channel_xr in create_zero_dataset.
  Update create_zero_dataset to support channel_xr Datasets with no dimensions by broadcasting the scalar value of each channel across all rows. Add tests to verify correct behavior when channel_xr is 0-dim and when some channels are missing.
* Test missing lines.
1 parent 670120a commit b67196e

File tree

2 files changed

+160
-18
lines changed

2 files changed

+160
-18
lines changed

pymc_marketing/mmm/utils.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -349,26 +349,41 @@ def create_zero_dataset(
349349
if date_col in channel_xr.dims:
350350
raise ValueError("`channel_xr` must NOT include the date dimension.")
351351

352-
# --- 4.3 Convert to DataFrame & merge ----------------------------------
353-
channel_df = channel_xr.to_dataframe().reset_index()
354-
355-
# Left-join on every dimension; suffix prevents collisions during merge
356-
pred_df = pred_df.merge(
357-
channel_df,
358-
on=dim_cols,
359-
how="left",
360-
suffixes=("", "_chan"),
361-
)
352+
# --- 4.3 Inject constants ----------------------------------------------
353+
# Special-case: when there are NO dims (e.g., only channel dimension in the
354+
# allocation which was pivoted into variables), xarray can't create an index
355+
# for to_dataframe(). In this scenario, simply broadcast scalar values
356+
# across all rows.
357+
if len(channel_xr.dims) == 0:
358+
for ch in channel_cols:
359+
if ch in channel_xr.data_vars:
360+
# assign scalar value across all rows
361+
try:
362+
pred_df[ch] = channel_xr[ch].item()
363+
except Exception:
364+
pred_df[ch] = channel_xr[ch].values
365+
else:
366+
# Convert to DataFrame & merge when dims are present
367+
channel_df = channel_xr.to_dataframe().reset_index()
368+
369+
# Left-join on every dimension; suffix prevents collisions during merge
370+
pred_df = pred_df.merge(
371+
channel_df,
372+
on=dim_cols,
373+
how="left",
374+
suffixes=("", "_chan"),
375+
)
362376

363377
# --- 4.4 Copy merged values into official channel columns --------------
364-
for ch in channel_cols:
365-
chan_col = f"{ch}_chan"
366-
if chan_col in pred_df.columns:
367-
pred_df[ch] = pred_df[chan_col]
368-
pred_df.drop(columns=chan_col, inplace=True)
369-
370-
# Replace any remaining NaNs introduced by the merge
371-
pred_df[channel_cols] = pred_df[channel_cols].fillna(0.0)
378+
if len(channel_xr.dims) != 0:
379+
for ch in channel_cols:
380+
chan_col = f"{ch}_chan"
381+
if chan_col in pred_df.columns:
382+
pred_df[ch] = pred_df[chan_col]
383+
pred_df.drop(columns=chan_col, inplace=True)
384+
385+
# Replace any remaining NaNs introduced by the merge
386+
pred_df[channel_cols] = pred_df[channel_cols].fillna(0.0)
372387

373388
# ---- 5. Bring in any “other” columns from the training data ----------------
374389
other_cols = [

tests/mmm/test_utils.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,47 @@ def test_create_zero_dataset_error_cases(self):
507507
):
508508
create_zero_dataset(model, start_date, end_date, date_dim_xr)
509509

510+
def test_create_zero_dataset_channel_xr_includes_date_specific_error(self):
    """A ``channel_xr`` indexed by the date column must raise the date-specific error.

    The fake model declares ``date`` as one of its dims so that the
    invalid-dims check passes and the explicit date-dimension ``ValueError``
    is the one asserted on.
    """

    class FakeAdstock:
        l_max = 1

    class FakeMMM_DateDim:
        def __init__(self):
            self.date_column = "date"
            self.channel_columns = ["channel1", "channel2"]
            self.control_columns = []
            # 'date' is listed as a model dim on purpose: it lets the
            # invalid-dims check pass so the date-dimension error is reached.
            self.dims = ["date"]
            self.adstock = FakeAdstock()

            training_dates = pd.date_range("2022-01-01", "2022-01-10", freq="D")
            self.X = pd.DataFrame(
                {
                    "date": training_dates,
                    "channel1": np.random.rand(10) * 10,
                    "channel2": np.random.rand(10) * 5,
                }
            )

    # Channel values that (incorrectly) carry a date coordinate.
    date_coord = pd.date_range("2022-01-01", periods=2, freq="D")
    channel_with_date = xr.Dataset(
        data_vars={"channel1": ("date", np.array([1.0, 2.0]))},
        coords={"date": date_coord},
    )

    fake_model = FakeMMM_DateDim()
    expected_message = r"`channel_xr` must NOT include the date dimension\."

    with pytest.raises(ValueError, match=expected_message):
        create_zero_dataset(fake_model, "2022-02-01", "2022-02-03", channel_with_date)
550+
510551
def test_create_zero_dataset_no_dims(self):
511552
"""Test create_zero_dataset with a model that has no dimensions."""
512553

@@ -549,3 +590,89 @@ def test_create_zero_dataset_empty_date_range_error(self):
549590

550591
with pytest.raises(ValueError, match="Generated date range is empty"):
551592
create_zero_dataset(model, start_date, end_date)
593+
594+
def test_create_zero_dataset_channel_xr_no_dims_all_channels(self):
    """A 0-dim ``channel_xr`` with one scalar per channel fills every row with it."""

    class FakeAdstock:
        l_max = 3

    class FakeMMM_NoDims:
        def __init__(self):
            self.date_column = "date"
            self.channel_columns = ["channel1", "channel2"]
            self.control_columns = []
            self.dims = []  # the model has no extra dimensions
            self.adstock = FakeAdstock()

            training_dates = pd.date_range("2022-01-01", "2022-01-10", freq="D")
            self.X = pd.DataFrame(
                {
                    "date": training_dates,
                    "channel1": np.random.rand(10) * 10,
                    "channel2": np.random.rand(10) * 5,
                }
            )

    # Dimensionless Dataset: each data variable is a channel holding a scalar.
    scalar_by_channel = {
        "channel1": 100.0,
        "channel2": 200.0,
    }
    channel_values = xr.Dataset(data_vars=scalar_by_channel)

    fake_model = FakeMMM_NoDims()
    result = create_zero_dataset(
        fake_model, "2022-02-01", "2022-02-05", channel_values
    )

    # 5 requested days plus l_max (3) extra days -> 8 rows
    assert len(result) == 8
    for channel_name, expected_value in scalar_by_channel.items():
        assert np.all(result[channel_name] == expected_value)
635+
636+
def test_create_zero_dataset_channel_xr_no_dims_missing_channel(self):
    """A 0-dim ``channel_xr`` lacking a channel warns and keeps that channel at 0."""

    class FakeAdstock:
        l_max = 2

    class FakeMMM_NoDims:
        def __init__(self):
            self.date_column = "date"
            self.channel_columns = ["channel1", "channel2"]
            self.control_columns = []
            self.dims = []
            self.adstock = FakeAdstock()

            training_dates = pd.date_range("2022-01-01", "2022-01-10", freq="D")
            self.X = pd.DataFrame(
                {
                    "date": training_dates,
                    "channel1": np.random.rand(10) * 10,
                    "channel2": np.random.rand(10) * 5,
                }
            )

    # Dimensionless Dataset that supplies a scalar for channel1 only.
    channel_values = xr.Dataset(data_vars={"channel1": 50.0})

    fake_model = FakeMMM_NoDims()
    warning_pattern = "does not supply values for \\['channel2'\\]"

    with pytest.warns(UserWarning, match=warning_pattern):
        result = create_zero_dataset(
            fake_model, "2022-02-01", "2022-02-03", channel_values
        )

    # 3 requested days plus l_max (2) extra days -> 5 rows
    assert len(result) == 5
    expected_by_channel = (("channel1", 50.0), ("channel2", 0.0))
    for channel_name, expected_value in expected_by_channel:
        assert np.all(result[channel_name] == expected_value)

0 commit comments

Comments
 (0)