Skip to content

Commit 6f66261

Browse files
committed
fix: shrink method of Categorical distribution
Signed-off-by: Louis Mandel <[email protected]>
1 parent 01db136 commit 6f66261

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/pdl/pdl_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def shrink(self) -> "Categorical[T]":
3535
res[v] = (w_v + w, m_v + m)
3636
else:
3737
res[v] = (w, m)
38-
return Categorical([(v, w, m) for v, (w, m) in res.items()])
38+
return Categorical([(v, np.log(w), m) for v, (w, m) in res.items()])
3939

4040
def sample(self) -> T:
4141
u = rand.rand()

0 commit comments

Comments
 (0)