@@ -64,7 +64,6 @@

__all__ = ["MMM", "BaseMMM"]

-# Constants
DEFAULT_HDI_PROB = 0.94


@@ -345,23 +344,21 @@ def _generate_and_preprocess_model_data(  # type: ignore
                f"Could not convert {self.date_column} to datetime. Please check the date format."
            ) from e

-        channel_data = X[self.channel_columns]
-
        coords: dict[str, Any] = {
            "channel": self.channel_columns,
            "date": date_data,
        }

-        new_X_dict = {
-            self.date_column: date_data,
-        }
-        X_data = pd.DataFrame.from_dict(new_X_dict)
-        X_data = pd.concat([X_data, channel_data], axis=1)
-        control_data: pd.DataFrame | pd.Series | None = None
+        # Build X_data efficiently by selecting columns once
+        columns_to_select = [self.date_column, *self.channel_columns]

        if self.control_columns is not None:
-            control_data = X[self.control_columns]
+            columns_to_select.extend(self.control_columns)
            coords["control"] = self.control_columns
-            X_data = pd.concat([X_data, control_data], axis=1)
+
+        # Create X_data with the parsed date column in one operation;
+        # cast to DataFrame to satisfy mypy type checking
+        X_data = pd.DataFrame(X[columns_to_select])
+        X_data[self.date_column] = date_data

        self.model_coords = coords
        if self.validate_data:
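As an aside, a minimal sketch of the select-once pattern with hypothetical column names — a single indexing call replaces the two `pd.concat` passes, and the parsed dates then overwrite the raw column:

```python
import pandas as pd

X = pd.DataFrame(
    {
        "date": ["2024-01-01", "2024-01-08"],
        "tv": [120.0, 90.0],
        "radio": [30.0, 45.0],
    }
)
columns_to_select = ["date", "tv", "radio"]
X_data = pd.DataFrame(X[columns_to_select])
X_data["date"] = pd.to_datetime(X_data["date"])  # parsed dates replace the raw strings
print(X_data.dtypes)
```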
@@ -423,11 +420,9 @@ def _compute_scale_for_data(
        else:
            raise ValueError(f"Unknown scaling method: {method}")

-        # Avoid division by zero
-        if isinstance(scale, np.ndarray):
-            scale = np.where(scale == 0, 1.0, scale)
-        else:
-            scale = 1.0 if scale == 0 else scale
+        # Avoid division by zero; np.where handles scalars and arrays alike
+        # and, unlike np.maximum, leaves non-zero scales untouched
+        scale = np.where(scale == 0, 1.0, scale)

        return scale

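A quick illustrative check of the guard: `np.where` replaces only exact zeros and accepts scalars and arrays alike, whereas `np.maximum(scale, 1.0)` would silently clamp any scale below one:

```python
import numpy as np

scales = np.array([0.0, 0.5, 3.0])
print(np.where(scales == 0, 1.0, scales))   # [1.  0.5 3. ] - 0.5 is preserved
print(np.maximum(scales, 1.0))              # [1.  1.  3. ] - 0.5 would be clamped
print(np.where(np.float64(0.0) == 0, 1.0, np.float64(0.0)))  # 1.0 (0-d array)
```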
@@ -438,8 +433,9 @@ def _compute_scales(self) -> None:
        if not isinstance(X_data, pd.DataFrame):
            raise TypeError("X data must be a DataFrame for scaling computation")

-        channel_data = np.asarray(X_data[self.channel_columns].values)
-        target_data = np.asarray(self.preprocessed_data["y"]).reshape(-1, 1)
+        # .to_numpy() avoids the redundant np.asarray(.values) copy; the target is kept 1-D
+        channel_data = X_data[self.channel_columns].to_numpy()
+        target_data = np.atleast_1d(np.asarray(self.preprocessed_data["y"]))

        # Compute scales based on scaling configuration
        self.channel_scale = self._compute_scale_for_data(
@@ -1052,15 +1048,19 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
            The initialization kwargs.

        """
+        # Parse the JSON attributes once up front
+        model_config = json.loads(attrs["model_config"])
+        adstock_dict = json.loads(attrs["adstock"])
+        saturation_dict = json.loads(attrs["saturation"])
+        scaling_dict = json.loads(attrs.get("scaling", "null"))
+
        return {
-            "model_config": cls._model_config_formatting(
-                json.loads(attrs["model_config"])
-            ),
+            "model_config": cls._model_config_formatting(model_config),
            "date_column": json.loads(attrs["date_column"]),
            "control_columns": json.loads(attrs["control_columns"]),
            "channel_columns": json.loads(attrs["channel_columns"]),
-            "adstock": adstock_from_dict(json.loads(attrs["adstock"])),
-            "saturation": saturation_from_dict(json.loads(attrs["saturation"])),
+            "adstock": adstock_from_dict(adstock_dict),
+            "saturation": saturation_from_dict(saturation_dict),
            "adstock_first": json.loads(attrs.get("adstock_first", "true")),
            "yearly_seasonality": json.loads(attrs["yearly_seasonality"]),
            "time_varying_intercept": json.loads(
@@ -1072,9 +1072,7 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
            "dag": json.loads(attrs.get("dag", "null")),
            "treatment_nodes": json.loads(attrs.get("treatment_nodes", "null")),
            "outcome_node": json.loads(attrs.get("outcome_node", "null")),
-            "scaling": cls._deserialize_scaling(
-                json.loads(attrs.get("scaling", "null"))
-            ),
+            "scaling": cls._deserialize_scaling(scaling_dict),
        }

    def _has_new_scaling(self) -> bool:
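For context, a hedged sketch of the attrs round trip — model attributes are stored as JSON strings, so optional keys default to the string `"null"`, which `json.loads` turns into `None`; the `attrs` dict here is hypothetical:

```python
import json

attrs = {"model_config": '{"intercept": {"dist": "Normal"}}'}

model_config = json.loads(attrs["model_config"])                # dict
scaling_dict = json.loads(attrs.get("scaling", "null"))         # None when absent
adstock_first = json.loads(attrs.get("adstock_first", "true"))  # True by default
print(model_config, scaling_dict, adstock_first)
```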
@@ -1502,18 +1500,24 @@ def get_channel_contribution_forward_pass_grid(

        share_grid = np.linspace(start=start, stop=stop, num=num)

+        # Extract and validate X_data once, outside the loop
+        X_data = self.preprocessed_data["X"]
+        if not isinstance(X_data, pd.DataFrame):
+            raise TypeError("X data must be a DataFrame")
+
+        base_channel_data = X_data[self.channel_columns].to_numpy()
+
+        # Accumulate one forward pass per grid point
        channel_contribution = []
        for delta in share_grid:
-            X_data = self.preprocessed_data["X"]
-            if isinstance(X_data, pd.DataFrame):
-                channel_data = delta * X_data[self.channel_columns].to_numpy()
-            else:
-                raise TypeError("X data must be a DataFrame")
+            # Scale the cached channel matrix instead of re-extracting it
+            channel_data = delta * base_channel_data

            channel_contribution_forward_pass = self.channel_contribution_forward_pass(
                channel_data=channel_data,
                disable_logger_stdout=True,
            )
            channel_contribution.append(channel_contribution_forward_pass)
+
        return DataArray(
            data=np.array(channel_contribution),
            dims=("delta", "chain", "draw", "date", "channel"),
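The hoisting pattern in miniature, with made-up names — the channel matrix is extracted once, and only the cheap scalar multiply runs per grid point:

```python
import numpy as np
import pandas as pd

X = pd.DataFrame({"tv": [1.0, 2.0], "radio": [3.0, 4.0]})
base = X[["tv", "radio"]].to_numpy()  # extracted once, outside the loop

for delta in np.linspace(0.0, 2.0, num=5):
    scaled = delta * base  # elementwise scaling per grid point, no re-indexing
```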
@@ -1549,9 +1553,11 @@ def plot_channel_parameter(self, param_name: str, **plt_kwargs: Any) -> plt.Figu
        saturation: SaturationTransformation = self.saturation
        adstock: AdstockTransformation = self.adstock

-        parameters_to_check = list(saturation.variable_mapping.values()) + list(
-            adstock.variable_mapping.values()
-        )
+        # Collect the valid parameter names from both transformations
+        parameters_to_check = [
+            *saturation.variable_mapping.values(),
+            *adstock.variable_mapping.values(),
+        ]
        if param_name not in parameters_to_check:
            raise ValueError(
                f"Invalid parameter name: {param_name}. Choose from {parameters_to_check}"
@@ -1638,17 +1644,17 @@ def _get_intercept_for_plot(
        )

        intercept_mean = intercept.mean(["chain", "draw"]).data
+        hdi_result = az.hdi(intercept).intercept.data

        if intercept.ndim == 2:
-            # Stationary intercept - repeat for all dates
-            intercept_hdi = np.repeat(
-                a=az.hdi(intercept).intercept.data[None, ...],
-                repeats=self.X[self.date_column].shape[0],
-                axis=0,
+            # Stationary intercept - broadcast across dates instead of repeating
+            n_dates = self.X[self.date_column].shape[0]
+            intercept_hdi = np.broadcast_to(
+                hdi_result[None, :], (n_dates, hdi_result.shape[0])
            )
        else:
            # Time-varying intercept
-            intercept_hdi = az.hdi(intercept).intercept.data
+            intercept_hdi = hdi_result

        return intercept_mean, intercept_hdi

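A small sketch of the broadcast trick: `np.broadcast_to` returns a read-only view, so repeating a (2,)-shaped HDI across all dates allocates no new data; the values are stand-ins:

```python
import numpy as np

hdi_result = np.array([0.1, 0.9])  # stand-in for az.hdi(...).intercept.data
n_dates = 4
intercept_hdi = np.broadcast_to(hdi_result[None, :], (n_dates, hdi_result.shape[0]))
print(intercept_hdi.shape)            # (4, 2)
print(intercept_hdi.flags.writeable)  # False - fine for plotting, copy before mutating
```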
@@ -1750,11 +1756,14 @@ def plot_components_contributions(
        intercept_mean, intercept_hdi = self._get_intercept_for_plot(original_scale)
        color_idx = len(means)

-        ax.plot(
-            dates,
-            np.full(len(dates), intercept_mean),
-            color=f"C{color_idx}",
-        )
+        # Draw the intercept as a horizontal line when it is a scalar
+        if np.ndim(intercept_mean) == 0:
+            # Stationary intercept - a single horizontal line
+            ax.axhline(y=intercept_mean, color=f"C{color_idx}")
+        else:
+            # Time-varying intercept
+            ax.plot(dates, intercept_mean, color=f"C{color_idx}")
+
        ax.fill_between(
            x=dates,
            y1=intercept_hdi[:, 0],
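The dispatch relies on `np.ndim`, which reports 0 for Python and NumPy scalars without converting the input:

```python
import numpy as np

print(np.ndim(1.3))           # 0 - scalar, drawn with ax.axhline
print(np.ndim(np.zeros(52)))  # 1 - time series, drawn with ax.plot
```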
@@ -2139,34 +2148,37 @@ def format_recovered_transformation_parameters(
        # Retrieve channel names
        channels = self.fit_result.channel.values

-        # Initialize the dictionary to store channel information
-        channels_info = {}
-
        # Define the parameter groups for consolidation
        param_groups = {
            "saturation_params": self.saturation.model_config.keys(),
            "adstock_params": self.adstock.model_config.keys(),
        }

-        # Iterate through each channel to fetch and store parameters
+        # Compute each parameter's quantiles once, outside the per-channel
+        # loop, and keep them as pandas Series for cheap per-channel lookups
+        quantile_cache = {}
+        for params in param_groups.values():
+            for param in params:
+                if param in self.fit_result:
+                    # Cache the quantile series keyed by parameter name
+                    quantile_cache[param] = (
+                        self.fit_result[param]
+                        .quantile(quantile, dim=["chain", "draw"])
+                        .to_pandas()
+                    )
+
+        # Assemble the per-channel parameter dictionaries
+        channels_info = {}
        for channel in channels:
            channel_info = {}
-
-            # Process each group of parameters (saturation and adstock)
            for group_name, params in param_groups.items():
-                # Build dictionary for the current group of parameters
+                prefix = group_name.split("_")[0] + "_"
                param_dict = {
-                    param.replace(group_name.split("_")[0] + "_", ""): self.fit_result[
-                        param
-                    ]
-                    .quantile(quantile, dim=["chain", "draw"])
-                    .to_pandas()
-                    .to_dict()[channel]
+                    param.replace(prefix, ""): quantile_cache[param].to_dict()[channel]
                    for param in params
-                    if param in self.fit_result
+                    if param in quantile_cache
                }
                channel_info[group_name] = param_dict
-
            channels_info[channel] = channel_info

        return channels_info
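A hedged sketch of the cache-then-lookup pattern on a toy posterior; the dataset layout mirrors `fit_result` but all names and values are made up:

```python
import numpy as np
import xarray as xr

posterior = xr.Dataset(
    {"saturation_lam": (("chain", "draw", "channel"), np.random.rand(2, 50, 2))},
    coords={"channel": ["tv", "radio"]},
)

quantile_cache = {
    "saturation_lam": posterior["saturation_lam"]
    .quantile(0.5, dim=["chain", "draw"])
    .to_pandas()
}
# Per-channel reads are now dict lookups instead of repeated quantile passes
print(quantile_cache["saturation_lam"].to_dict()["tv"])
```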
@@ -2720,16 +2732,22 @@ def _generate_future_dates(
        list[pd.Timestamp]
            List of future dates
        """
-        offset_map = {
-            "daily": lambda i: pd.DateOffset(days=i),
-            "weekly": lambda i: pd.DateOffset(weeks=i),
-            "monthly": lambda i: pd.DateOffset(months=i),
-            "quarterly": lambda i: pd.DateOffset(months=3 * i),
-            "yearly": lambda i: pd.DateOffset(years=i),
-        }
-
-        offset_func = offset_map[time_granularity]
-        return [last_date + offset_func(i) for i in range(1, time_length + 1)]
+        # Generate all dates with a single pd.date_range call. DateOffset
+        # frequencies preserve the anchor of last_date (day of week, day of
+        # month), unlike anchored aliases such as "W" or "MS", which would
+        # snap the dates to Sundays or month starts
+        freq_map = {
+            "daily": pd.DateOffset(days=1),
+            "weekly": pd.DateOffset(weeks=1),
+            "monthly": pd.DateOffset(months=1),
+            "quarterly": pd.DateOffset(months=3),
+            "yearly": pd.DateOffset(years=1),
+        }
+
+        # Drop last_date itself; keep the time_length future dates
+        return pd.date_range(
+            start=last_date, periods=time_length + 1, freq=freq_map[time_granularity]
+        )[1:].tolist()

    def _create_synth_dataset(
        self,
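Why `DateOffset` frequencies rather than anchored aliases — a short demonstration with an assumed mid-week start date:

```python
import pandas as pd

last_date = pd.Timestamp("2025-01-01")  # a Wednesday

print(pd.date_range(start=last_date, periods=3, freq="W")[1:].tolist())
# Sundays: [2025-01-12, 2025-01-19] - the Wednesday cadence is lost

print(pd.date_range(start=last_date, periods=3, freq=pd.DateOffset(weeks=1))[1:].tolist())
# Wednesdays: [2025-01-08, 2025-01-15] - continues the original cadence
```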
@@ -2805,24 +2823,27 @@ def _create_synth_dataset(
            last_date, time_granularity, time_length
        )

-        # Create synthetic rows
-        new_rows = [
-            {
-                self.date_column: pd.to_datetime(new_date),
-                **{
-                    channel: allocation_strategy.sel(channel=channel).values
-                    + np.random.normal(
-                        0, noise_level * allocation_strategy.sel(channel=channel).values
-                    )
-                    for channel in channels
-                },
-                **{control: 0 for control in _controls},
-                target_col: 0,
-            }
-            for new_date in new_dates
-        ]
+        # Vectorized creation of the synthetic dataset
+        # Extract allocation values once, aligned to the channel order used below
+        channel_allocations = allocation_strategy.to_pandas().reindex(channels)
+
+        # Multiplicative noise matches the original row-wise std of noise_level * allocation
+        noise = np.random.normal(0, noise_level, size=(time_length, len(channels)))
+        channel_values = channel_allocations.values * (1 + noise)
+
+        # Build the DataFrame from a dict of arrays
+        data_dict: dict[str, Any] = {self.date_column: new_dates}
+        data_dict.update(dict(zip(channels, channel_values.T, strict=False)))
+
+        # Controls and target are zero-filled arrays for type consistency
+        if _controls:
+            zeros_array = np.zeros(time_length)
+            for control in _controls:
+                data_dict[control] = zeros_array
+
+        data_dict[target_col] = np.zeros(time_length)

-        return pd.DataFrame(new_rows)
+        return pd.DataFrame(data_dict)

    def sample_response_distribution(
        self,
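Finally, an illustrative check that the vectorized noise is distributionally equivalent to the original row-wise sampling — `allocation * (1 + eps)` with `eps ~ N(0, s)` has standard deviation `s * allocation`:

```python
import numpy as np

rng = np.random.default_rng(0)
allocation, s, n = 100.0, 0.05, 200_000

row_wise = allocation + rng.normal(0, s * allocation, size=n)
vectorized = allocation * (1 + rng.normal(0, s, size=n))
print(round(row_wise.std(), 1), round(vectorized.std(), 1))  # both ~= 5.0
```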