2929import numpy as np
3030import numpy .typing as npt
3131import pymc as pm
32+ import pymc .dims as pmd
3233import xarray as xr
3334from matplotlib .axes import Axes
3435from matplotlib .figure import Figure
3536from pydantic import InstanceOf
3637from pymc .distributions .shape_utils import Dims
37- from pymc_extras .prior import Prior , VariableFactory , create_dim_handler
38+ from pymc_extras .prior import Prior , VariableFactory
3839from pytensor import tensor as pt
3940from pytensor .tensor .variable import TensorVariable
41+ from pytensor .xtensor import as_xtensor
4042
43+ from pymc_marketing .mmm .dims import XPrior
4144from pymc_marketing .model_config import parse_model_config
4245from pymc_marketing .plot import (
4346 SelToString ,
@@ -220,13 +223,13 @@ def priors(self) -> dict[str, SupportedPrior]:
220223 return self .function_priors
221224
222225 @function_priors .setter # type: ignore
223- def function_priors (self , priors : dict [str , Any | Prior ] | None ) -> None :
226+ def function_priors (self , priors : dict [str , Any | XPrior ] | None ) -> None :
224227 priors = priors or {}
225228
226229 non_distributions = [
227230 key
228231 for key , value in priors .items ()
229- if not isinstance (value , Prior ) and not isinstance (value , dict )
232+ if not isinstance (value , XPrior ) and not isinstance (value , dict )
230233 ]
231234
232235 priors = parse_model_config (priors , non_distributions = non_distributions )
@@ -313,9 +316,10 @@ def _has_defaults_for_all_arguments(self) -> None:
313316 function_signature = signature (self .function )
314317
315318 # Remove the first one as assumed to be the data
319+ # And the last one, as assumed to be dim
316320 parameters_that_need_priors = set (
317321 list (function_signature .parameters .keys ())[1 :]
318- )
322+ ) - { "dim" }
319323 parameters_with_priors = set (self .default_priors .keys ())
320324
321325 missing_priors = parameters_that_need_priors - parameters_with_priors
@@ -379,8 +383,6 @@ def _create_distributions(
379383 if idx is not None :
380384 dims = ("N" , * dims )
381385
382- dim_handler = create_dim_handler (dims )
383-
384386 def create_variable (parameter_name : str , variable_name : str ) -> TensorVariable :
385387 dist = self .function_priors [parameter_name ]
386388 if not hasattr (dist , "create_variable" ):
@@ -392,10 +394,7 @@ def create_variable(parameter_name: str, variable_name: str) -> TensorVariable:
392394 if idx is not None and any (dim in idx for dim in dist_dims ):
393395 var = index_variable (var , dist .dims , idx )
394396
395- dist_dims = [dim for dim in dist_dims if dim not in idx ]
396- dist_dims = ("N" , * dist_dims )
397-
398- return dim_handler (var , dist_dims )
397+ return var
399398
400399 return {
401400 parameter_name : create_variable (parameter_name , variable_name )
@@ -521,16 +520,19 @@ def _sample_curve(
521520 }
522521 )
523522
523+ x = as_xtensor (x , dims = (x_dim ,))
524+
524525 with pm .Model (coords = coords ):
525- pm .Deterministic (
526+ pmd .Deterministic (
526527 var_name ,
527- self .apply (x , dims = output_core_dims ),
528+ self .apply (x , dims = output_core_dims , core_dim = x_dim ),
528529 dims = (x_dim , * output_core_dims ),
529530 )
530531
531532 return pm .sample_posterior_predictive (
532533 parameters ,
533534 var_names = [var_name ],
535+ progressbar = False ,
534536 ).posterior_predictive [var_name ]
535537
536538 def plot_curve_samples (
@@ -616,7 +618,9 @@ def plot_curve_hdi(
616618 def apply (
617619 self ,
618620 x : pt .TensorLike ,
621+ * ,
619622 dims : Dims | None = None ,
623+ core_dim : str ,
620624 idx : dict [str , pt .TensorLike ] | None = None ,
621625 ) -> TensorVariable :
622626 """Call within a model context.
@@ -646,13 +650,18 @@ def apply(
646650
647651 transformation = ...
648652
649- coords = {"channel": ["TV", "Radio", "Digital"]}
653+ coords = {
654+ "channel": ["TV", "Radio", "Digital"],
655+ "date": range(10),
656+ }
650657 with pm.Model(coords=coords):
651- transformed_data = transformation.apply(data, dims="channel")
658+ transformed_data = transformation.apply(
659+ data, dims="channel", core_dim="date"
660+ )
652661
653662 """
654663 kwargs = self ._create_distributions (dims = dims , idx = idx )
655- return self .function (x , ** kwargs )
664+ return self .function (x , dim = core_dim , ** kwargs )
656665
657666
658667def _serialize_value (value : Any ) -> Any :
0 commit comments