Skip to content

Commit 01db136

Browse files
committed
feat: add prob method to to the categorical distribution
Signed-off-by: Louis Mandel <[email protected]>
1 parent 434b49f commit 01db136

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

src/pdl/pdl_distributions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ def sort(self) -> "Categorical[T]":
5151
d.metadata = [d.metadata[i] for i in sorted_indices]
5252
return d
5353

54+
def prob(self, x: T) -> float:
55+
dist = self.shrink()
56+
try:
57+
i = dist.values.index(x)
58+
p = dist.probs[i]
59+
except ValueError:
60+
p = 0.0
61+
return p
62+
5463

5564
def viz(dist: Categorical[float], **kwargs):
5665
"""

0 commit comments

Comments
 (0)