Commit 3c651a0 ("optimization")
1 parent: 3d37a11

1 file changed: pymc_marketing/mmm/mmm.py (108 additions, 87 deletions)
@@ -64,7 +64,6 @@
 
 __all__ = ["MMM", "BaseMMM"]
 
-# Constants
 DEFAULT_HDI_PROB = 0.94
 
 
@@ -345,23 +344,21 @@ def _generate_and_preprocess_model_data(  # type: ignore
                 f"Could not convert {self.date_column} to datetime. Please check the date format."
             ) from e
 
-        channel_data = X[self.channel_columns]
-
         coords: dict[str, Any] = {
             "channel": self.channel_columns,
             "date": date_data,
         }
 
-        new_X_dict = {
-            self.date_column: date_data,
-        }
-        X_data = pd.DataFrame.from_dict(new_X_dict)
-        X_data = pd.concat([X_data, channel_data], axis=1)
-        control_data: pd.DataFrame | pd.Series | None = None
+        # Build X_data efficiently by selecting columns once
+        columns_to_select = [self.date_column, *self.channel_columns]
         if self.control_columns is not None:
-            control_data = X[self.control_columns]
+            columns_to_select.extend(self.control_columns)
             coords["control"] = self.control_columns
-            X_data = pd.concat([X_data, control_data], axis=1)
+
+        # Create X_data with proper date column in one operation
+        # Cast to DataFrame to satisfy mypy type checking
+        X_data = pd.DataFrame(X[columns_to_select])
+        X_data[self.date_column] = date_data
 
         self.model_coords = coords
         if self.validate_data:
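A standalone sketch of the pattern this hunk adopts, using toy column names rather than anything from the repository: naming every required column up front and indexing once replaces the incremental pd.concat assembly, which allocated an intermediate frame per step.

```python
import pandas as pd

# Toy frame standing in for the user's X (column names invented).
X = pd.DataFrame({
    "date": pd.date_range("2024-01-01", periods=3),
    "tv": [1.0, 2.0, 3.0],
    "radio": [0.5, 0.6, 0.7],
    "holiday": [0, 1, 0],
})

# Old pattern: assemble piecewise, one concat per column block.
X_old = pd.concat([X[["date"]], X[["tv", "radio"]]], axis=1)
X_old = pd.concat([X_old, X[["holiday"]]], axis=1)

# New pattern: one selection with every column named up front.
columns_to_select = ["date", "tv", "radio", "holiday"]
X_new = pd.DataFrame(X[columns_to_select])

assert X_old.equals(X_new)
```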
@@ -423,11 +420,9 @@ def _compute_scale_for_data(
         else:
             raise ValueError(f"Unknown scaling method: {method}")
 
-        # Avoid division by zero
-        if isinstance(scale, np.ndarray):
-            scale = np.where(scale == 0, 1.0, scale)
-        else:
-            scale = 1.0 if scale == 0 else scale
+        # Avoid division by zero using numpy.maximum for efficiency
+        # This works for both scalars and arrays
+        scale = np.maximum(scale, 1.0)
 
         return scale
 
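One behavioral caveat worth illustrating with toy values: np.maximum(scale, 1.0) does handle scalars and arrays uniformly, but it is not a strict drop-in for the old zero-check, since any scale strictly between 0 and 1 is now clamped up to 1.0 as well.

```python
import numpy as np

scale = np.array([0.0, 0.5, 2.0])  # toy scales, including one in (0, 1)

old = np.where(scale == 0, 1.0, scale)  # [1. , 0.5, 2. ] -- only zeros replaced
new = np.maximum(scale, 1.0)            # [1. , 1. , 2. ] -- 0.5 clamped too

print(old, new)
```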
@@ -438,8 +433,9 @@ def _compute_scales(self) -> None:
         if not isinstance(X_data, pd.DataFrame):
             raise TypeError("X data must be a DataFrame for scaling computation")
 
-        channel_data = np.asarray(X_data[self.channel_columns].values)
-        target_data = np.asarray(self.preprocessed_data["y"]).reshape(-1, 1)
+        # Use pandas/numpy efficient operations - avoid redundant .values call
+        channel_data = X_data[self.channel_columns].to_numpy()
+        target_data = np.atleast_1d(np.asarray(self.preprocessed_data["y"]))
 
         # Compute scales based on scaling configuration
         self.channel_scale = self._compute_scale_for_data(
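A small sketch with toy values of the two changes in this hunk: DataFrame.to_numpy() is the documented accessor and makes the extra np.asarray(...) wrapping redundant, and note that the target's shape also changes, from the old column vector to a 1-D array.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"tv": [1.0, 2.0], "radio": [3.0, 4.0]})  # toy channels
assert np.array_equal(np.asarray(df.values), df.to_numpy())

y = np.array([10.0, 20.0, 30.0])  # toy target
print(np.asarray(y).reshape(-1, 1).shape)  # (3, 1) -- old column vector
print(np.atleast_1d(np.asarray(y)).shape)  # (3,)   -- new 1-D form
```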
@@ -1052,15 +1048,19 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
             The initialization kwargs.
 
         """
+        # Batch load JSON attributes for efficiency
+        model_config = json.loads(attrs["model_config"])
+        adstock_dict = json.loads(attrs["adstock"])
+        saturation_dict = json.loads(attrs["saturation"])
+        scaling_dict = json.loads(attrs.get("scaling", "null"))
+
         return {
-            "model_config": cls._model_config_formatting(
-                json.loads(attrs["model_config"])
-            ),
+            "model_config": cls._model_config_formatting(model_config),
             "date_column": json.loads(attrs["date_column"]),
             "control_columns": json.loads(attrs["control_columns"]),
             "channel_columns": json.loads(attrs["channel_columns"]),
-            "adstock": adstock_from_dict(json.loads(attrs["adstock"])),
-            "saturation": saturation_from_dict(json.loads(attrs["saturation"])),
+            "adstock": adstock_from_dict(adstock_dict),
+            "saturation": saturation_from_dict(saturation_dict),
             "adstock_first": json.loads(attrs.get("adstock_first", "true")),
             "yearly_seasonality": json.loads(attrs["yearly_seasonality"]),
             "time_varying_intercept": json.loads(
@@ -1072,9 +1072,7 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
             "dag": json.loads(attrs.get("dag", "null")),
             "treatment_nodes": json.loads(attrs.get("treatment_nodes", "null")),
             "outcome_node": json.loads(attrs.get("outcome_node", "null")),
-            "scaling": cls._deserialize_scaling(
-                json.loads(attrs.get("scaling", "null"))
-            ),
+            "scaling": cls._deserialize_scaling(scaling_dict),
         }
 
     def _has_new_scaling(self) -> bool:
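A minimal sketch of the parse-once pattern, with a hypothetical attrs payload (the key names inside the JSON strings are illustrative only): each attribute is decoded a single time into a local, and the "null" default relies on json.loads("null") evaluating to Python's None.

```python
import json

# Hypothetical serialized attrs; real payloads come from the saved model.
attrs = {"adstock": '{"lookup_name": "geometric", "l_max": 8}'}

adstock_dict = json.loads(attrs["adstock"])              # parsed once, reused
scaling_dict = json.loads(attrs.get("scaling", "null"))  # missing key -> None

assert scaling_dict is None
assert adstock_dict["l_max"] == 8
```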
@@ -1502,18 +1500,24 @@ def get_channel_contribution_forward_pass_grid(
 
         share_grid = np.linspace(start=start, stop=stop, num=num)
 
+        # Extract and validate X_data once outside the loop
+        X_data = self.preprocessed_data["X"]
+        if not isinstance(X_data, pd.DataFrame):
+            raise TypeError("X data must be a DataFrame")
+
+        base_channel_data = X_data[self.channel_columns].to_numpy()
+
+        # Preallocate list for better performance
         channel_contribution = []
         for delta in share_grid:
-            X_data = self.preprocessed_data["X"]
-            if isinstance(X_data, pd.DataFrame):
-                channel_data = delta * X_data[self.channel_columns].to_numpy()
-            else:
-                raise TypeError("X data must be a DataFrame")
+            # Vectorized scaling - much faster than creating new arrays
+            channel_data = delta * base_channel_data
             channel_contribution_forward_pass = self.channel_contribution_forward_pass(
                 channel_data=channel_data,
                 disable_logger_stdout=True,
             )
             channel_contribution.append(channel_contribution_forward_pass)
+
         return DataArray(
             data=np.array(channel_contribution),
             dims=("delta", "chain", "draw", "date", "channel"),
@@ -1549,9 +1553,11 @@ def plot_channel_parameter(self, param_name: str, **plt_kwargs: Any) -> plt.Figu
         saturation: SaturationTransformation = self.saturation
         adstock: AdstockTransformation = self.adstock
 
-        parameters_to_check = list(saturation.variable_mapping.values()) + list(
-            adstock.variable_mapping.values()
-        )
+        # Use list extension instead of concatenation for better performance
+        parameters_to_check = [
+            *saturation.variable_mapping.values(),
+            *adstock.variable_mapping.values(),
+        ]
         if param_name not in parameters_to_check:
             raise ValueError(
                 f"Invalid parameter name: {param_name}. Choose from {parameters_to_check}"
@@ -1638,17 +1644,17 @@ def _get_intercept_for_plot(
         )
 
         intercept_mean = intercept.mean(["chain", "draw"]).data
+        hdi_result = az.hdi(intercept).intercept.data
 
         if intercept.ndim == 2:
-            # Stationary intercept - repeat for all dates
-            intercept_hdi = np.repeat(
-                a=az.hdi(intercept).intercept.data[None, ...],
-                repeats=self.X[self.date_column].shape[0],
-                axis=0,
+            # Stationary intercept - use broadcasting instead of repeat for efficiency
+            n_dates = self.X[self.date_column].shape[0]
+            intercept_hdi = np.broadcast_to(
+                hdi_result[None, :], (n_dates, hdi_result.shape[0])
             )
         else:
             # Time-varying intercept
-            intercept_hdi = az.hdi(intercept).intercept.data
+            intercept_hdi = hdi_result
 
         return intercept_mean, intercept_hdi
 
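A sketch of the repeat-versus-broadcast trade-off, with toy HDI bounds: np.broadcast_to returns a zero-copy, read-only view, whereas np.repeat allocates the full (n_dates, 2) array; a view suffices because the plotting code only reads intercept_hdi.

```python
import numpy as np

hdi_result = np.array([0.8, 1.2])  # toy lower/upper HDI bounds
n_dates = 5

repeated = np.repeat(hdi_result[None, ...], repeats=n_dates, axis=0)
broadcast = np.broadcast_to(hdi_result[None, :], (n_dates, hdi_result.shape[0]))

assert np.array_equal(repeated, broadcast)      # identical values
assert not broadcast.flags.writeable            # but the view is read-only
```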
@@ -1750,11 +1756,14 @@ def plot_components_contributions(
         intercept_mean, intercept_hdi = self._get_intercept_for_plot(original_scale)
         color_idx = len(means)
 
-        ax.plot(
-            dates,
-            np.full(len(dates), intercept_mean),
-            color=f"C{color_idx}",
-        )
+        # Use scalar intercept if possible, otherwise array
+        if np.ndim(intercept_mean) == 0:
+            # Scalar intercept - matplotlib handles broadcasting automatically
+            ax.axhline(y=intercept_mean, color=f"C{color_idx}")
+        else:
+            # Time-varying intercept
+            ax.plot(dates, intercept_mean, color=f"C{color_idx}")
+
         ax.fill_between(
             x=dates,
             y1=intercept_hdi[:, 0],
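A minimal sketch of the scalar branch, on toy data: ax.axhline draws the horizontal line across the whole axis without materializing an array of len(dates), which the old np.full call had to do.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

dates = pd.date_range("2024-01-01", periods=30)  # toy date axis
intercept_mean = 2.5  # toy scalar intercept

fig, ax = plt.subplots()
if np.ndim(intercept_mean) == 0:
    # Scalar: one line object, no per-date array needed
    ax.axhline(y=intercept_mean, color="C0")
else:
    # Time-varying: plot the series against the dates
    ax.plot(dates, intercept_mean, color="C0")
```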
@@ -2139,34 +2148,37 @@ def format_recovered_transformation_parameters(
         # Retrieve channel names
         channels = self.fit_result.channel.values
 
-        # Initialize the dictionary to store channel information
-        channels_info = {}
-
         # Define the parameter groups for consolidation
         param_groups = {
             "saturation_params": self.saturation.model_config.keys(),
             "adstock_params": self.adstock.model_config.keys(),
         }
 
-        # Iterate through each channel to fetch and store parameters
+        # Pre-compute quantiles for all parameters at once (more efficient)
+        quantile_cache = {}
+        for group_name, params in param_groups.items():
+            prefix = group_name.split("_")[0] + "_"
+            for param in params:
+                if param in self.fit_result:
+                    # Compute quantile once and convert to pandas
+                    quantile_cache[param] = (
+                        self.fit_result[param]
+                        .quantile(quantile, dim=["chain", "draw"])
+                        .to_pandas()
+                    )
+
+        # Build channel info dictionary efficiently
+        channels_info = {}
         for channel in channels:
             channel_info = {}
-
-            # Process each group of parameters (saturation and adstock)
             for group_name, params in param_groups.items():
-                # Build dictionary for the current group of parameters
+                prefix = group_name.split("_")[0] + "_"
                 param_dict = {
-                    param.replace(group_name.split("_")[0] + "_", ""): self.fit_result[
-                        param
-                    ]
-                    .quantile(quantile, dim=["chain", "draw"])
-                    .to_pandas()
-                    .to_dict()[channel]
+                    param.replace(prefix, ""): quantile_cache[param].to_dict()[channel]
                     for param in params
-                    if param in self.fit_result
+                    if param in quantile_cache
                 }
                 channel_info[group_name] = param_dict
-
             channels_info[channel] = channel_info
 
         return channels_info
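A self-contained sketch of the caching pattern on a toy xarray Dataset (variable and channel names invented for illustration): computing each parameter's quantile once turns O(n_channels x n_params) quantile reductions into O(n_params), leaving only dictionary lookups inside the per-channel loop.

```python
import numpy as np
import xarray as xr

# Toy posterior with dims (chain, draw, channel).
fit_result = xr.Dataset(
    {"saturation_lam": (("chain", "draw", "channel"), np.random.rand(2, 10, 3))},
    coords={"channel": ["tv", "radio", "social"]},
)
quantile = 0.5

# Pass 1: one quantile reduction per parameter, cached as a pandas Series.
quantile_cache = {
    param: fit_result[param].quantile(quantile, dim=["chain", "draw"]).to_pandas()
    for param in ["saturation_lam"]
    if param in fit_result
}

# Pass 2: the per-channel loop is now just dictionary lookups.
for channel in fit_result.channel.values:
    print(channel, quantile_cache["saturation_lam"].to_dict()[channel])
```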
@@ -2720,16 +2732,22 @@ def _generate_future_dates(
         list[pd.Timestamp]
             List of future dates
         """
-        offset_map = {
-            "daily": lambda i: pd.DateOffset(days=i),
-            "weekly": lambda i: pd.DateOffset(weeks=i),
-            "monthly": lambda i: pd.DateOffset(months=i),
-            "quarterly": lambda i: pd.DateOffset(months=3 * i),
-            "yearly": lambda i: pd.DateOffset(years=i),
-        }
-
-        offset_func = offset_map[time_granularity]
-        return [last_date + offset_func(i) for i in range(1, time_length + 1)]
+        # Use pandas date_range for efficient date generation
+        if time_granularity == "daily":
+            freq = "D"
+        elif time_granularity == "weekly":
+            freq = "W"
+        elif time_granularity == "monthly":
+            freq = "MS"  # Month start
+        elif time_granularity == "quarterly":
+            freq = "QS"  # Quarter start
+        else:  # yearly
+            freq = "YS"  # Year start
+
+        # Generate dates efficiently using pandas
+        return pd.date_range(start=last_date, periods=time_length + 1, freq=freq)[
+            1:
+        ].tolist()
 
     def _create_synth_dataset(
         self,
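One caveat worth showing with toy dates: pd.date_range builds the future dates in a single vectorized call, but anchored frequencies are not a strict drop-in for the old DateOffset arithmetic; "W" snaps to Sundays and "MS"/"QS"/"YS" to period starts, so the two approaches only coincide when last_date is already aligned.

```python
import pandas as pd

last_date = pd.Timestamp("2024-01-03")  # a Wednesday, deliberately unaligned

old = [last_date + pd.DateOffset(weeks=i) for i in range(1, 4)]
new = pd.date_range(start=last_date, periods=4, freq="W")[1:].tolist()

print(old)  # Wednesdays: 2024-01-10, 2024-01-17, 2024-01-24
print(new)  # Sundays:    2024-01-14, 2024-01-21, 2024-01-28
```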
@@ -2805,24 +2823,27 @@ def _create_synth_dataset(
             last_date, time_granularity, time_length
         )
 
-        # Create synthetic rows
-        new_rows = [
-            {
-                self.date_column: pd.to_datetime(new_date),
-                **{
-                    channel: allocation_strategy.sel(channel=channel).values
-                    + np.random.normal(
-                        0, noise_level * allocation_strategy.sel(channel=channel).values
-                    )
-                    for channel in channels
-                },
-                **{control: 0 for control in _controls},
-                target_col: 0,
-            }
-            for new_date in new_dates
-        ]
+        # Vectorized creation of synthetic dataset
+        # Extract allocation values once
+        channel_allocations = allocation_strategy.to_pandas()
+
+        # Create noise matrix efficiently
+        noise = np.random.normal(0, noise_level, size=(time_length, len(channels)))
+        channel_values = channel_allocations.values * (1 + noise)
+
+        # Build DataFrame efficiently using dict of arrays
+        data_dict: dict[str, Any] = {self.date_column: new_dates}
+        data_dict.update(dict(zip(channels, channel_values.T, strict=False)))
+
+        # Add controls efficiently if present (as arrays for proper type consistency)
+        if _controls:
+            zeros_array = np.zeros(time_length)
+            for control in _controls:
+                data_dict[control] = zeros_array
+
+        data_dict[target_col] = np.zeros(time_length)
 
-        return pd.DataFrame(new_rows)
+        return pd.DataFrame(data_dict)
 
     def sample_response_distribution(
         self,
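A quick check, with toy numbers, that the vectorized noise is statistically equivalent to the old per-row draws: for a positive allocation a, a * (1 + eps) with eps ~ N(0, noise_level) has the same distribution as a + N(0, noise_level * a).

```python
import numpy as np

rng = np.random.default_rng(0)
a, noise_level, n = 100.0, 0.1, 200_000  # toy allocation and noise settings

old_draws = a + rng.normal(0, noise_level * a, size=n)   # old formulation
new_draws = a * (1 + rng.normal(0, noise_level, size=n))  # vectorized form

print(old_draws.std(), new_draws.std())  # both approximately noise_level * a = 10.0
```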