14
14
"""Media Mix Model class."""
15
15
16
16
import json
17
+ import logging
17
18
import warnings
18
19
from typing import Annotated , Any , Literal
19
20
@@ -582,21 +583,29 @@ def default_model_config(self) -> dict:
582
583
}
583
584
584
585
def channel_contributions_forward_pass (
585
- self , channel_data : npt .NDArray [np .float64 ]
586
+ self ,
587
+ channel_data : npt .NDArray [np .float64 ],
588
+ disable_logger_stdout : bool | None = False ,
586
589
) -> npt .NDArray [np .float64 ]:
587
590
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
588
591
589
592
Parameters
590
593
----------
591
594
channel_data : array-like
592
595
Input channel data. Result of all the preprocessing steps.
596
+ disable_logger_stdout : bool, optional
597
+ If True, suppress logger output to stdout
593
598
594
599
Returns
595
600
-------
596
601
array-like
597
602
Transformed channel data.
598
603
599
604
"""
605
+ if disable_logger_stdout :
606
+ logger = logging .getLogger ("pymc.sampling.forward" )
607
+ logger .propagate = False
608
+
600
609
coords = {
601
610
** self .model_coords ,
602
611
}
@@ -925,7 +934,9 @@ class MMM(
925
934
version : str = "0.0.2"
926
935
927
936
def channel_contributions_forward_pass (
928
- self , channel_data : npt .NDArray [np .float64 ]
937
+ self ,
938
+ channel_data : npt .NDArray [np .float64 ],
939
+ disable_logger_stdout : bool | None = False ,
929
940
) -> npt .NDArray [np .float64 ]:
930
941
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
931
942
@@ -935,6 +946,8 @@ def channel_contributions_forward_pass(
935
946
----------
936
947
channel_data : array-like
937
948
Input channel data. Result of all the preprocessing steps.
949
+ disable_logger_stdout : bool, optional
950
+ If True, suppress logger output to stdout
938
951
939
952
Returns
940
953
-------
@@ -943,7 +956,7 @@ def channel_contributions_forward_pass(
943
956
944
957
"""
945
958
channel_contribution_forward_pass = super ().channel_contributions_forward_pass (
946
- channel_data = channel_data
959
+ channel_data = channel_data , disable_logger_stdout = disable_logger_stdout
947
960
)
948
961
target_transformed_vectorized = np .vectorize (
949
962
self .target_transformer .inverse_transform ,
@@ -983,7 +996,7 @@ def get_channel_contributions_forward_pass_grid(
983
996
delta * self .preprocessed_data ["X" ][self .channel_columns ].to_numpy ()
984
997
)
985
998
channel_contribution_forward_pass = self .channel_contributions_forward_pass (
986
- channel_data = channel_data
999
+ channel_data = channel_data , disable_logger_stdout = True
987
1000
)
988
1001
channel_contributions .append (channel_contribution_forward_pass )
989
1002
return DataArray (
0 commit comments