Skip to content

Commit 269dd75

Browse files
committed
dist and rv init commit
1 parent 7d62c53 commit 269dd75

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

pymc_extras/distributions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,5 @@
3737
"R2D2M2CP",
3838
"Skellam",
3939
"histogram_approximation",
40+
"GrassiaIIGeometric",
4041
]

pymc_extras/distributions/discrete.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,3 +397,103 @@ def dist(cls, mu1, mu2, **kwargs):
397397
class_name="Skellam",
398398
**kwargs,
399399
)
400+
401+
402+
class GrassiaIIGeometricRV(RandomVariable):
403+
name = "g2g"
404+
signature = "(),()->()"
405+
406+
dtype = "int64"
407+
_print_name = ("GrassiaIIGeometric", "\\operatorname{GrassiaIIGeometric}")
408+
409+
def __call__(self, r, alpha, size=None, **kwargs):
410+
return super().__call__(r, alpha, size=size, **kwargs)
411+
412+
@classmethod
413+
def rng_fn(cls, rng, r, alpha, size):
414+
if size is None:
415+
size = np.broadcast_shapes(r.shape, alpha.shape)
416+
417+
r = np.asarray(r)
418+
alpha = np.asarray(alpha)
419+
420+
r = np.broadcast_to(r, size)
421+
alpha = np.broadcast_to(alpha, size)
422+
423+
output = np.zeros(shape=size + (1,)) # noqa:RUF005
424+
425+
lam = rng.gamma(shape=r, scale=1 / alpha, size=size)
426+
427+
def sim_data(lam):
428+
# TODO: To support time-varying covariates, covariate vector may need to be added
429+
p = 1 - np.exp(-lam)
430+
431+
t = rng.geometric(p)
432+
433+
return np.array([t])
434+
435+
for index in np.ndindex(*size):
436+
output[index] = sim_data(lam[index])
437+
438+
return output
439+
440+
441+
g2g = GrassiaIIGeometricRV()
442+
443+
444+
class GrassiaIIGeometric(UnitContinuous):
445+
r"""Grassia(II)-Geometric distribution for a discrete-time, contractual customer population.
446+
447+
Described by Hardie and Fader in [1]_, this distribution is comprised by the following PMF and survival functions:
448+
449+
.. math::
450+
\mathbb{P}T=t|r,\alpha,\beta;Z(t)) = (\frac{\alpha}{\alpha+C(t-1)})^{r} - (\frac{\alpha}{\alpha+C(t)})^{r} \\
451+
\begin{align}
452+
\mathbb{S}(t|r,\alpha,\beta;Z(t)) = (\frac{\alpha}{\alpha+C(t)})^{r} \\
453+
\end{align}
454+
======== ===============================================
455+
Support :math:`0 < t <= T` for :math: `t = 1, 2, \dots, T`
456+
======== ===============================================
457+
458+
Parameters
459+
----------
460+
r : tensor_like of float
461+
Shape parameter of Gamma distribution describing customer heterogeneity. (r > 0)
462+
alpha : tensor_like of float
463+
Scale parameter of Gamma distribution describing customer heterogeneity. (alpha > 0)
464+
465+
References
466+
----------
467+
.. [1] Fader, Peter & G. S. Hardie, Bruce (2020).
468+
"Incorporating Time-Varying Covariates in a Simple Mixture Model for Discrete-Time Duration Data."
469+
https://www.brucehardie.com/notes/037/time-varying_covariates_in_BG.pdf
470+
"""
471+
472+
rv_op = g2g
473+
474+
@classmethod
475+
def dist(cls, r, alpha, *args, **kwargs):
476+
r = pt.as_tensor_variable(r)
477+
alpha = pt.as_tensor_variable(alpha)
478+
return super().dist([r, alpha], *args, **kwargs)
479+
480+
def logp(value, r, alpha):
481+
logp = -r * (pt.log(alpha + value - 1) + pt.log(alpha + value))
482+
483+
return check_parameters(
484+
logp,
485+
r > 0,
486+
alpha > 0,
487+
msg="s > 0, alpha > 0",
488+
)
489+
490+
def logcdf(value, r, alpha):
491+
# TODO: Math may not be correct here
492+
logcdf = r * (pt.log(value) - pt.log(alpha + value))
493+
494+
return check_parameters(
495+
logcdf,
496+
r > 0,
497+
alpha > 0,
498+
msg="s > 0, alpha > 0",
499+
)

0 commit comments

Comments
 (0)