@@ -764,11 +764,11 @@ def plot_transforms(
764764 x_range = None ,
765765 ** kwargs ,
766766 ) -> Tuple [plt .Figure , np .ndarray ]:
767- """Plot estimated saturation and adstock transformation curves .
767+ """Plot estimated transformation curves ( saturation and/or adstock) .
768768
769- Creates a 2-panel figure showing :
770- 1. Saturation curve (input exposure -> saturated exposure)
771- 2. Adstock weights over time (lag distribution)
769+ Creates a figure with 1-2 panels depending on which transforms are present :
770+ - Saturation curve (input exposure -> saturated exposure) if saturation exists
771+ - Adstock weights over time (lag distribution) if adstock exists
772772
773773 Parameters
774774 ----------
@@ -784,8 +784,8 @@ def plot_transforms(
784784 Returns
785785 -------
786786 fig : matplotlib.figure.Figure
787- ax : array of matplotlib.axes.Axes
788- Array of 2 axes objects (left: saturation, right: adstock ).
787+ ax : list of matplotlib.axes.Axes
788+ List of axes objects (1 or 2 panels depending on which transforms exist ).
789789
790790 Examples
791791 --------
@@ -810,13 +810,33 @@ def plot_transforms(
810810 est_saturation = treatment .saturation
811811 est_adstock = treatment .adstock
812812
813- # Create 2-panel subplot
814- fig , axes = plt .subplots (1 , 2 , figsize = (14 , 5 ))
813+ # Check which transforms exist
814+ has_saturation = est_saturation is not None
815+ has_adstock = est_adstock is not None
816+
817+ if not has_saturation and not has_adstock :
818+ raise ValueError (
819+ "No transforms to plot (both saturation and adstock are None). "
820+ "At least one transform must be specified."
821+ )
822+
823+ # Determine number of panels based on available transforms
824+ n_panels = int (has_saturation ) + int (has_adstock )
825+
826+ # Create subplot with appropriate number of panels
827+ fig , axes = plt .subplots (1 , n_panels , figsize = (7 * n_panels , 5 ))
828+
829+ # Make axes a list for consistent indexing
830+ if n_panels == 1 :
831+ axes = [axes ]
832+
833+ panel_idx = 0
815834
816835 # ============================================================================
817- # LEFT PLOT: Saturation curves
836+ # SATURATION PLOT (if present)
818837 # ============================================================================
819838 if est_saturation is not None :
839+ ax = axes [panel_idx ]
820840 # Determine x range
821841 if x_range is None :
822842 # Use range from data
@@ -831,7 +851,7 @@ def plot_transforms(
831851 # Plot true saturation if provided
832852 if true_saturation is not None :
833853 y_true_sat = true_saturation .apply (x_sat )
834- axes [ 0 ] .plot (
854+ ax .plot (
835855 x_sat ,
836856 y_true_sat ,
837857 "k--" ,
@@ -842,13 +862,13 @@ def plot_transforms(
842862
843863 # Plot estimated saturation
844864 y_est_sat = est_saturation .apply (x_sat )
845- axes [ 0 ] .plot (x_sat , y_est_sat , "C0-" , linewidth = 2.5 , label = "Estimated" )
865+ ax .plot (x_sat , y_est_sat , "C0-" , linewidth = 2.5 , label = "Estimated" )
846866
847- axes [ 0 ] .set_xlabel (f"{ treatment .name } (raw)" , fontsize = 11 )
848- axes [ 0 ] .set_ylabel ("Saturated Value" , fontsize = 11 )
849- axes [ 0 ] .set_title ("Saturation Function" , fontsize = 12 , fontweight = "bold" )
850- axes [ 0 ] .legend (fontsize = LEGEND_FONT_SIZE , framealpha = 0.9 )
851- axes [ 0 ] .grid (True , alpha = 0.3 )
867+ ax .set_xlabel (f"{ treatment .name } (raw)" , fontsize = 11 )
868+ ax .set_ylabel ("Saturated Value" , fontsize = 11 )
869+ ax .set_title ("Saturation Function" , fontsize = 12 , fontweight = "bold" )
870+ ax .legend (fontsize = LEGEND_FONT_SIZE , framealpha = 0.9 )
871+ ax .grid (True , alpha = 0.3 )
852872
853873 # Add parameter text
854874 est_params = est_saturation .get_params ()
@@ -864,30 +884,22 @@ def plot_transforms(
864884 if key not in ["alpha" , "l_max" , "normalize" ]:
865885 param_text += f" { key } ={ val :.2f} \n "
866886
867- axes [ 0 ] .text (
887+ ax .text (
868888 0.05 ,
869889 0.95 ,
870890 param_text .strip (),
871- transform = axes [ 0 ] .transAxes ,
891+ transform = ax .transAxes ,
872892 fontsize = 9 ,
873893 verticalalignment = "top" ,
874894 bbox = dict (boxstyle = "round" , facecolor = "wheat" , alpha = 0.5 ),
875895 )
876- else :
877- axes [0 ].text (
878- 0.5 ,
879- 0.5 ,
880- "No saturation transform" ,
881- ha = "center" ,
882- va = "center" ,
883- transform = axes [0 ].transAxes ,
884- )
885- axes [0 ].set_title ("Saturation Function" , fontsize = 12 , fontweight = "bold" )
896+ panel_idx += 1
886897
887898 # ============================================================================
888- # RIGHT PLOT: Adstock weights
899+ # ADSTOCK PLOT (if present)
889900 # ============================================================================
890901 if est_adstock is not None :
902+ ax = axes [panel_idx ]
891903 est_adstock_params = est_adstock .get_params ()
892904 l_max = est_adstock_params .get ("l_max" , 12 )
893905 lags = np .arange (l_max + 1 )
@@ -908,15 +920,15 @@ def plot_transforms(
908920 true_weights = true_weights / true_weights .sum ()
909921
910922 width = 0.35
911- axes [ 1 ] .bar (
923+ ax .bar (
912924 lags - width / 2 ,
913925 true_weights ,
914926 width ,
915927 alpha = 0.8 ,
916928 label = "True" ,
917929 color = "gray" ,
918930 )
919- axes [ 1 ] .bar (
931+ ax .bar (
920932 lags + width / 2 ,
921933 est_weights ,
922934 width ,
@@ -925,15 +937,15 @@ def plot_transforms(
925937 color = "C0" ,
926938 )
927939 else :
928- axes [ 1 ] .bar (lags , est_weights , alpha = 0.7 , color = "C0" , label = "Estimated" )
940+ ax .bar (lags , est_weights , alpha = 0.7 , color = "C0" , label = "Estimated" )
929941
930- axes [ 1 ] .set_xlabel ("Lag (periods)" , fontsize = 11 )
931- axes [ 1 ] .set_ylabel ("Adstock Weight" , fontsize = 11 )
932- axes [ 1 ] .set_title (
942+ ax .set_xlabel ("Lag (periods)" , fontsize = 11 )
943+ ax .set_ylabel ("Adstock Weight" , fontsize = 11 )
944+ ax .set_title (
933945 "Adstock Function (Carryover Effect)" , fontsize = 12 , fontweight = "bold"
934946 )
935- axes [ 1 ] .legend (fontsize = LEGEND_FONT_SIZE , framealpha = 0.9 )
936- axes [ 1 ] .grid (True , alpha = 0.3 , axis = "y" )
947+ ax .legend (fontsize = LEGEND_FONT_SIZE , framealpha = 0.9 )
948+ ax .grid (True , alpha = 0.3 , axis = "y" )
937949
938950 # Add parameter text
939951 param_text = "Estimated:\n "
@@ -948,28 +960,16 @@ def plot_transforms(
948960 param_text += f" half_life={ half_life_true :.2f} \n "
949961 param_text += f" alpha={ true_alpha :.3f} \n "
950962
951- axes [ 1 ] .text (
963+ ax .text (
952964 0.95 ,
953965 0.95 ,
954966 param_text .strip (),
955- transform = axes [ 1 ] .transAxes ,
967+ transform = ax .transAxes ,
956968 fontsize = 9 ,
957969 verticalalignment = "top" ,
958970 horizontalalignment = "right" ,
959971 bbox = dict (boxstyle = "round" , facecolor = "wheat" , alpha = 0.5 ),
960972 )
961- else :
962- axes [1 ].text (
963- 0.5 ,
964- 0.5 ,
965- "No adstock transform" ,
966- ha = "center" ,
967- va = "center" ,
968- transform = axes [1 ].transAxes ,
969- )
970- axes [1 ].set_title (
971- "Adstock Function (Carryover Effect)" , fontsize = 12 , fontweight = "bold"
972- )
973973
974974 plt .tight_layout ()
975975 return fig , axes
0 commit comments