@@ -308,9 +308,7 @@ def create_idata_attrs(self) -> dict[str, str]:
308
308
309
309
return attrs
310
310
311
- def forward_pass (
312
- self , x : pt .TensorVariable | npt .NDArray [np .float64 ]
313
- ) -> pt .TensorVariable :
311
+ def forward_pass (self , x : pt .TensorVariable | npt .NDArray ) -> pt .TensorVariable :
314
312
"""Transform channel input into target contributions of each channel.
315
313
316
314
This method handles the ordering of the adstock and saturation
@@ -322,7 +320,7 @@ def forward_pass(
322
320
323
321
Parameters
324
322
----------
325
- x : pt.TensorVariable | npt.NDArray[np.float64]
323
+ x : pt.TensorVariable | npt.NDArray
326
324
The channel input which could be spends or impressions
327
325
328
326
Returns
@@ -586,9 +584,9 @@ def default_model_config(self) -> dict:
586
584
587
585
def channel_contributions_forward_pass (
588
586
self ,
589
- channel_data : npt .NDArray [ np . float64 ] ,
587
+ channel_data : npt .NDArray ,
590
588
disable_logger_stdout : bool | None = False ,
591
- ) -> npt .NDArray [ np . float64 ] :
589
+ ) -> npt .NDArray :
592
590
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
593
591
594
592
Parameters
@@ -945,9 +943,9 @@ class MMM(
945
943
946
944
def channel_contributions_forward_pass (
947
945
self ,
948
- channel_data : npt .NDArray [ np . float64 ] ,
946
+ channel_data : npt .NDArray ,
949
947
disable_logger_stdout : bool | None = False ,
950
- ) -> npt .NDArray [ np . float64 ] :
948
+ ) -> npt .NDArray :
951
949
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
952
950
953
951
We return the contribution in the original scale of the target variable.
0 commit comments