@@ -523,24 +523,30 @@ class BayesianBasisExpansionTimeSeries(PyMCModel):
523523 ----------
524524 n_order : int, optional
525525 The number of Fourier components for the yearly seasonality. Defaults to 3.
526+ Only used if seasonality_component is None.
526527 n_changepoints_trend : int, optional
527528 The number of changepoints for the linear trend component. Defaults to 10.
529+ Only used if trend_component is None.
530+ prior_sigma : float, optional
531+ Prior standard deviation for the observation noise. Defaults to 5.
532+ trend_component : Optional[Any], optional
533+ A custom trend component model. If None, the default pymc-marketing LinearTrend component is used.
534+ Must have an `apply(time_data)` method that returns a PyMC tensor.
535+ seasonality_component : Optional[Any], optional
536+ A custom seasonality component model. If None, the default pymc-marketing YearlyFourier component is used.
537+ Must have an `apply(time_data)` method that returns a PyMC tensor.
528538 sample_kwargs : dict, optional
529539 A dictionary of kwargs that get unpacked and passed to the
530540 :func:`pymc.sample` function. Defaults to an empty dictionary.
531- trend_component : Optional[Any], optional
532- A custom trend component model. If None, the default pymc-marketing trend component is used.
533- seasonality_component : Optional[Any], optional
534- A custom seasonality component model. If None, the default pymc-marketing seasonality `YearlyFourier` component is used.
535541 """ # noqa: W605
536542
537543 def __init__ (
538544 self ,
539545 n_order : int = 3 ,
540546 n_changepoints_trend : int = 10 ,
541547 prior_sigma : float = 5 ,
542- # Removed trend_component and seasonality_component for now to simplify
543- # They can be added back if pymc-marketing is a hard dependency or via other logic
548+ trend_component : Optional [ Any ] = None ,
549+ seasonality_component : Optional [ Any ] = None ,
544550 sample_kwargs : Optional [Dict [str , Any ]] = None ,
545551 ):
546552 super ().__init__ (sample_kwargs = sample_kwargs )
@@ -552,9 +558,74 @@ def __init__(
552558 self ._first_fit_timestamp : Optional [pd .Timestamp ] = None
553559 self ._exog_var_names : Optional [List [str ]] = None
554560
555- # pymc-marketing components will be initialized in build_model
556- # self._yearly_fourier = None
557- # self._linear_trend = None
561+ # Store custom components (fix the bug where they were swapped)
562+ self ._custom_trend_component = trend_component
563+ self ._custom_seasonality_component = seasonality_component
564+
565+ # Initialize and validate components
566+ self ._trend_component = None
567+ self ._seasonality_component = None
568+ self ._validate_and_initialize_components ()
569+
570+ def _validate_and_initialize_components (self ):
571+ """
572+ Validate and initialize trend and seasonality components.
573+ This separates validation from model building for cleaner code.
574+ """
575+ # Validate pymc-marketing availability if using default components
576+ if (
577+ self ._custom_trend_component is None
578+ or self ._custom_seasonality_component is None
579+ ):
580+ try :
581+ from pymc_marketing .mmm import LinearTrend , YearlyFourier
582+
583+ self ._PymcMarketingLinearTrend = LinearTrend
584+ self ._PymcMarketingYearlyFourier = YearlyFourier
585+ except ImportError :
586+ raise ImportError (
587+ "pymc-marketing is required when using default trend or seasonality components. "
588+ "Please install it with `pip install pymc-marketing` or provide custom components."
589+ )
590+
591+ # Validate custom components have required methods
592+ if self ._custom_trend_component is not None :
593+ if not hasattr (self ._custom_trend_component , "apply" ):
594+ raise ValueError (
595+ "Custom trend_component must have an 'apply' method that accepts time data "
596+ "and returns a PyMC tensor."
597+ )
598+
599+ if self ._custom_seasonality_component is not None :
600+ if not hasattr (self ._custom_seasonality_component , "apply" ):
601+ raise ValueError (
602+ "Custom seasonality_component must have an 'apply' method that accepts time data "
603+ "and returns a PyMC tensor."
604+ )
605+
606+ def _get_trend_component (self ):
607+ """Get the trend component, creating default if needed."""
608+ if self ._custom_trend_component is not None :
609+ return self ._custom_trend_component
610+
611+ # Create default trend component
612+ if self ._trend_component is None :
613+ self ._trend_component = self ._PymcMarketingLinearTrend (
614+ n_changepoints = self .n_changepoints_trend
615+ )
616+ return self ._trend_component
617+
618+ def _get_seasonality_component (self ):
619+ """Get the seasonality component, creating default if needed."""
620+ if self ._custom_seasonality_component is not None :
621+ return self ._custom_seasonality_component
622+
623+ # Create default seasonality component
624+ if self ._seasonality_component is None :
625+ self ._seasonality_component = self ._PymcMarketingYearlyFourier (
626+ n_order = self .n_order
627+ )
628+ return self ._seasonality_component
558629
559630 def _prepare_time_and_exog_features (
560631 self ,
@@ -665,9 +736,6 @@ def build_model(
665736
666737 # Get exog_names from coords["coeffs"] if X_exog_array is present
667738 exog_names_from_coords = coords .get ("coeffs" )
668- # This will be further processed into a list by _prepare_time_and_exog_features
669- # if isinstance(exog_names_from_coords, str): # Handle single coeff name
670- # exog_names_from_coords = [exog_names_from_coords]
671739
672740 (
673741 time_for_trend ,
@@ -738,44 +806,19 @@ def build_model(
738806 "t_season_data" , time_for_seasonality , dims = "obs_ind" , mutable = True
739807 )
740808
741- # Attempt to import and instantiate pymc_marketing components here
742- _PymcMarketingLinearTrend = None
743- _PymcMarketingYearlyFourier = None
744- pymc_marketing_available = False
745- try :
746- from pymc_marketing .mmm import LinearTrend as PymcMLinearTrend
747- from pymc_marketing .mmm import YearlyFourier as PymcMYearlyFourier
748-
749- _PymcMarketingLinearTrend = PymcMLinearTrend
750- _PymcMarketingYearlyFourier = PymcMYearlyFourier
751- pymc_marketing_available = True
752- except ImportError :
753- # pymc-marketing is not available. This is handled conditionally below.
754- pass
755-
756- if not pymc_marketing_available :
757- raise ImportError (
758- "pymc-marketing is required. "
759- "Please install it with `pip install pymc-marketing`."
760- )
761-
762- # Instantiate components for this specific build_model call
763- local_yearly_fourier = _PymcMarketingYearlyFourier (n_order = self .n_order )
764- local_linear_trend = _PymcMarketingLinearTrend (
765- n_changepoints = self .n_changepoints_trend
766- )
809+ # Get validated components (no more ugly imports in build_model!)
810+ trend_component_instance = self ._get_trend_component ()
811+ seasonality_component_instance = self ._get_seasonality_component ()
767812
768813 # Seasonal component
769814 season_component = pm .Deterministic (
770815 "season_component" ,
771- local_yearly_fourier .apply (t_season_data ), # Use local instance
816+ seasonality_component_instance .apply (t_season_data ),
772817 dims = "obs_ind" ,
773818 )
774819
775820 # Trend component
776- trend_component_values = local_linear_trend .apply (
777- t_trend_data
778- ) # Use local instance
821+ trend_component_values = trend_component_instance .apply (t_trend_data )
779822 trend_component = pm .Deterministic (
780823 "trend_component" ,
781824 trend_component_values ,
0 commit comments