Skip to content

Commit 66d8b32

Browse files
authored
Fix: Use **kwargs instead of kwargs parameter in MNLogit methods (#2131)
Updated method signatures for sample_prior_predictive, fit, and sample_posterior_predictive to use **kwargs directly instead of accepting kwargs as a regular parameter. This allows proper keyword argument unpacking when calling PyMC sampling functions. Also updated the sample() method to unpack the kwargs dictionaries when calling these methods. Fixes #2111
1 parent 2259553 commit 66d8b32

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

pymc_marketing/customer_choice/mnl_logit.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def create_idata_attrs(self) -> dict[str, str]:
373373

374374
return attrs
375375

376-
def sample_prior_predictive(self, extend_idata, kwargs):
376+
def sample_prior_predictive(self, extend_idata, **kwargs):
377377
"""Sample Prior Predictive Distribution."""
378378
with self.model: # sample with new input data
379379
prior_pred: az.InferenceData = pm.sample_prior_predictive(500, **kwargs)
@@ -385,7 +385,7 @@ def sample_prior_predictive(self, extend_idata, kwargs):
385385
else:
386386
self.idata = prior_pred
387387

388-
def fit(self, extend_idata, kwargs):
388+
def fit(self, extend_idata, **kwargs):
389389
"""Fit Nested Logit Model."""
390390
if extend_idata:
391391
with self.model:
@@ -394,7 +394,7 @@ def fit(self, extend_idata, kwargs):
394394
with self.model:
395395
self.idata = pm.sample(**kwargs)
396396

397-
def sample_posterior_predictive(self, extend_idata, kwargs):
397+
def sample_posterior_predictive(self, extend_idata, **kwargs):
398398
"""Sample Posterior Predictive Distribution."""
399399
if extend_idata:
400400
with self.model:
@@ -448,11 +448,11 @@ def sample(
448448
self.model = model
449449

450450
self.sample_prior_predictive(
451-
extend_idata=True, kwargs=sample_prior_predictive_kwargs
451+
extend_idata=True, **sample_prior_predictive_kwargs
452452
)
453-
self.fit(extend_idata=True, kwargs=fit_kwargs)
453+
self.fit(extend_idata=True, **fit_kwargs)
454454
self.sample_posterior_predictive(
455-
extend_idata=True, kwargs=sample_posterior_predictive_kwargs
455+
extend_idata=True, **sample_posterior_predictive_kwargs
456456
)
457457
return self
458458

0 commit comments

Comments
 (0)