@@ -127,6 +127,10 @@ class BudgetOptimizer(BaseModel):
127
127
Custom constraints for the optimizer.
128
128
default_constraints : bool, optional
129
129
Whether to add a default sum constraint on the total budget. Default is True.
130
+ budget_distribution_over_period : xarray.DataArray, optional
131
+ Distribution factors for budget allocation over time. Should have dims ("date", *budget_dims)
132
+ where date dimension has length num_periods. Values along date dimension should sum to 1 for
133
+ each combination of other dimensions. If None, budget is distributed evenly across periods.
130
134
"""
131
135
132
136
num_periods : int = Field (
@@ -169,6 +173,15 @@ class BudgetOptimizer(BaseModel):
169
173
description = "Whether to add a default sum constraint on the total budget." ,
170
174
)
171
175
176
+ budget_distribution_over_period : DataArray | None = Field (
177
+ default = None ,
178
+ description = (
179
+ "Distribution factors for budget allocation over time. Should have dims ('date', *budget_dims) "
180
+ "where date dimension has length num_periods. Values along date dimension should sum to 1 for "
181
+ "each combination of other dimensions. If None, budget is distributed evenly across periods."
182
+ ),
183
+ )
184
+
172
185
model_config = ConfigDict (arbitrary_types_allowed = True )
173
186
174
187
DEFAULT_MINIMIZE_KWARGS : ClassVar [dict ] = {
@@ -230,16 +243,26 @@ def __init__(self, **data):
230
243
bool_mask = np .asarray (self .budgets_to_optimize ).astype (bool )
231
244
self ._budgets = budgets_zeros [bool_mask ].set (self ._budgets_flat )
232
245
233
- # 5. Replace channel_data with budgets in the PyMC model
246
+ # 5. Validate and process budget_distribution_over_period
247
+ self ._budget_distribution_over_period_tensor = (
248
+ self ._validate_and_process_budget_distribution (
249
+ budget_distribution_over_period = self .budget_distribution_over_period ,
250
+ num_periods = self .num_periods ,
251
+ budget_dims = self ._budget_dims ,
252
+ budgets_to_optimize = self .budgets_to_optimize ,
253
+ )
254
+ )
255
+
256
+ # 6. Replace channel_data with budgets in the PyMC model
234
257
self ._pymc_model = self ._replace_channel_data_by_optimization_variable (
235
258
pymc_model
236
259
)
237
260
238
- # 6 . Compile objective & gradient
261
+ # 7 . Compile objective & gradient
239
262
self ._compiled_functions = {}
240
263
self ._compile_objective_and_grad ()
241
264
242
- # 7 . Build constraints
265
+ # 8 . Build constraints
243
266
self ._constraints = {}
244
267
self .set_constraints (
245
268
default = self .default_constraints , constraints = self .custom_constraints
@@ -272,6 +295,126 @@ def set_constraints(self, constraints, default=None) -> None:
272
295
constraints = self ._constraints , optimizer = self
273
296
)
274
297
298
def _validate_and_process_budget_distribution(
    self,
    budget_distribution_over_period: DataArray | None,
    num_periods: int,
    budget_dims: list[str],
    budgets_to_optimize: DataArray,
) -> pt.TensorVariable | None:
    """Validate and process budget distribution over periods.

    Parameters
    ----------
    budget_distribution_over_period : DataArray | None
        Distribution factors for budget allocation over time.
    num_periods : int
        Number of time periods to allocate budget for.
    budget_dims : list[str]
        List of budget dimensions (excluding 'date').
    budgets_to_optimize : DataArray
        Mask defining which budgets to optimize. It is cast to bool before
        use, so a 0/1 integer mask behaves the same as a boolean one.

    Returns
    -------
    pt.TensorVariable | None
        Constant tensor of shape (num_periods, num_optimized_budgets) holding
        the time factors of the optimized budgets only, or None if no
        distribution was provided.

    Raises
    ------
    ValueError
        If the dims are not exactly ``("date", *budget_dims)`` (in any order),
        if the date dimension does not have length ``num_periods``, or if the
        factors do not sum to 1 along the date dimension.
    """
    if budget_distribution_over_period is None:
        return None

    # Dimension *names* must match; the order is not enforced because the
    # array is transposed to ("date", *budget_dims) further down.
    expected_dims = ("date", *budget_dims)
    if set(budget_distribution_over_period.dims) != set(expected_dims):
        raise ValueError(
            f"budget_distribution_over_period must have dims {expected_dims}, "
            f"but got {budget_distribution_over_period.dims}"
        )

    # Use ``sizes`` rather than ``coords`` so that a date dimension without
    # an attached coordinate is still validated instead of raising KeyError.
    n_dates = budget_distribution_over_period.sizes["date"]
    if n_dates != num_periods:
        raise ValueError(
            f"budget_distribution_over_period date dimension must have length {num_periods}, "
            f"but got {n_dates}"
        )

    # Factors must sum to 1 along the date dimension
    sums = budget_distribution_over_period.sum(dim="date")
    if not np.allclose(sums.values, 1.0, rtol=1e-5):
        raise ValueError(
            "budget_distribution_over_period must sum to 1 along the date dimension "
            "for each combination of other dimensions"
        )

    # Pre-process: keep only the factors of the optimized budgets.
    # This avoids shape mismatches during gradient computation.
    time_factors_full = budget_distribution_over_period.transpose(
        *expected_dims
    ).values

    # Reshape to (num_periods, flat_budget_dims) and apply the mask.
    time_factors_flat = time_factors_full.reshape((num_periods, -1))
    # Cast to bool (as done wherever else the mask is applied): an integer
    # 0/1 mask would otherwise be interpreted as column *indices*.
    bool_mask = np.asarray(budgets_to_optimize).astype(bool).ravel()
    time_factors_masked = time_factors_flat[:, bool_mask]

    # Store only the masked tensor
    return pt.constant(time_factors_masked, name="budget_distribution_over_period")
362
+
363
def _apply_budget_distribution_over_period(
    self,
    budgets: pt.TensorVariable,
    num_periods: int,
    date_dim_idx: int,
) -> pt.TensorVariable:
    """Spread the budgets over time using the stored distribution factors.

    Parameters
    ----------
    budgets : pt.TensorVariable
        The scaled budget tensor with shape matching budget dimensions.
    num_periods : int
        Number of time periods to distribute budget across.
    date_dim_idx : int
        Index position where the date dimension should be inserted.

    Returns
    -------
    pt.TensorVariable
        Budget tensor with one slice per time period, each weighted by its
        distribution factor. The periods are stacked along ``date_dim_idx``,
        i.e. shape (*budget_dims[:date_dim_idx], num_periods,
        *budget_dims[date_dim_idx:]).
    """
    # The stored factors are already restricted to the optimized budgets and
    # have shape (num_periods, num_optimized_budgets), so the same selection
    # has to be applied to the full-shape ``budgets`` tensor first.
    optimize_mask = np.asarray(self.budgets_to_optimize).astype(bool)

    # (1, num_optimized_budgets) row, broadcast against the factor matrix
    # to give (num_periods, num_optimized_budgets).
    optimized_row = pt.expand_dims(budgets[optimize_mask], 0)
    weighted_flat = optimized_row * self._budget_distribution_over_period_tensor

    # For every period, scatter the weighted values back into a zero tensor
    # of the full budget shape; non-optimized entries stay at zero.
    per_period_full = [
        pt.zeros_like(budgets)[optimize_mask].set(weighted_flat[period])
        for period in range(num_periods)
    ]

    # The factors sum to 1 over the date axis, whereas the even-distribution
    # path repeats ``budgets`` unchanged in every period. Rescaling by
    # ``num_periods`` keeps the total over all periods the same in both paths.
    return pt.stack(per_period_full, axis=date_dim_idx) * num_periods
417
+
275
418
def _replace_channel_data_by_optimization_variable (self , model : Model ) -> Model :
276
419
"""Replace `channel_data` in the model graph with our newly created `_budgets` variable."""
277
420
num_periods = self .num_periods
@@ -287,10 +430,19 @@ def _replace_channel_data_by_optimization_variable(self, model: Model) -> Model:
287
430
# Repeat budgets over num_periods
288
431
repeated_budgets_shape = list (tuple (budgets .shape ))
289
432
repeated_budgets_shape .insert (date_dim_idx , num_periods )
290
- repeated_budgets = pt .broadcast_to (
291
- pt .expand_dims (budgets , date_dim_idx ),
292
- shape = repeated_budgets_shape ,
293
- )
433
+
434
+ if self ._budget_distribution_over_period_tensor is not None :
435
+ # Apply time distribution factors
436
+ repeated_budgets = self ._apply_budget_distribution_over_period (
437
+ budgets , num_periods , date_dim_idx
438
+ )
439
+ else :
440
+ # Default behavior: distribute evenly across periods
441
+ repeated_budgets = pt .broadcast_to (
442
+ pt .expand_dims (budgets , date_dim_idx ),
443
+ shape = repeated_budgets_shape ,
444
+ )
445
+
294
446
repeated_budgets .name = "repeated_budgets"
295
447
296
448
# Pad the repeated budgets with zeros to account for carry-over effects
0 commit comments