Skip to content

Commit d0070ca

Browse files
authored
Fill docstrings (#1043)
1 parent 7975bc2 commit d0070ca

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

pymc_marketing/mmm/mmm.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Media Mix Model class."""
1515

1616
import json
17+
import logging
1718
import warnings
1819
from typing import Annotated, Any, Literal
1920

@@ -582,21 +583,29 @@ def default_model_config(self) -> dict:
582583
}
583584

584585
def channel_contributions_forward_pass(
585-
self, channel_data: npt.NDArray[np.float64]
586+
self,
587+
channel_data: npt.NDArray[np.float64],
588+
disable_logger_stdout: bool | None = False,
586589
) -> npt.NDArray[np.float64]:
587590
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
588591
589592
Parameters
590593
----------
591594
channel_data : array-like
592595
Input channel data. Result of all the preprocessing steps.
596+
disable_logger_stdout : bool, optional
597+
If True, suppress logger output to stdout
593598
594599
Returns
595600
-------
596601
array-like
597602
Transformed channel data.
598603
599604
"""
605+
if disable_logger_stdout:
606+
logger = logging.getLogger("pymc.sampling.forward")
607+
logger.propagate = False
608+
600609
coords = {
601610
**self.model_coords,
602611
}
@@ -925,7 +934,9 @@ class MMM(
925934
version: str = "0.0.2"
926935

927936
def channel_contributions_forward_pass(
928-
self, channel_data: npt.NDArray[np.float64]
937+
self,
938+
channel_data: npt.NDArray[np.float64],
939+
disable_logger_stdout: bool | None = False,
929940
) -> npt.NDArray[np.float64]:
930941
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
931942
@@ -935,6 +946,8 @@ def channel_contributions_forward_pass(
935946
----------
936947
channel_data : array-like
937948
Input channel data. Result of all the preprocessing steps.
949+
disable_logger_stdout : bool, optional
950+
If True, suppress logger output to stdout
938951
939952
Returns
940953
-------
@@ -943,7 +956,7 @@ def channel_contributions_forward_pass(
943956
944957
"""
945958
channel_contribution_forward_pass = super().channel_contributions_forward_pass(
946-
channel_data=channel_data
959+
channel_data=channel_data, disable_logger_stdout=disable_logger_stdout
947960
)
948961
target_transformed_vectorized = np.vectorize(
949962
self.target_transformer.inverse_transform,
@@ -983,7 +996,7 @@ def get_channel_contributions_forward_pass_grid(
983996
delta * self.preprocessed_data["X"][self.channel_columns].to_numpy()
984997
)
985998
channel_contribution_forward_pass = self.channel_contributions_forward_pass(
986-
channel_data=channel_data
999+
channel_data=channel_data, disable_logger_stdout=True
9871000
)
9881001
channel_contributions.append(channel_contribution_forward_pass)
9891002
return DataArray(

0 commit comments

Comments
 (0)