Skip to content

Commit c369091

Browse files
hamzamerzicDistraxDev
authored andcommitted
Remove explicit type casting in quantized.
PiperOrigin-RevId: 462610652
1 parent 0ecad05 commit c369091

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

distrax/_src/distributions/quantized.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ==============================================================================
1515
"""Quantized distribution."""
1616

17-
from typing import cast, Optional, Tuple
17+
from typing import Optional, Tuple
1818

1919
import chex
2020
from distrax._src.distributions import distribution as base_distribution
@@ -61,8 +61,7 @@ def __init__(self,
6161
`distribution` and must not result in additional batch dimensions after
6262
broadcasting.
6363
"""
64-
self._dist: base_distribution.Distribution[Array, Tuple[
65-
int, ...], jnp.dtype] = conversion.as_distribution(distribution)
64+
self._dist = conversion.as_distribution(distribution)
6665
if self._dist.event_shape:
6766
raise ValueError(f'The base distribution must be univariate, but its '
6867
f'`event_shape` is {self._dist.event_shape}.')
@@ -107,9 +106,7 @@ def high(self) -> Optional[Array]:
107106
@property
108107
def event_shape(self) -> Tuple[int, ...]:
109108
"""Shape of event of distribution samples."""
110-
event_shape = self.distribution.event_shape
111-
# TODO(b/149413467): Remove explicit casting when resolved.
112-
return cast(Tuple[int, ...], event_shape)
109+
return self.distribution.event_shape
113110

114111
@property
115112
def batch_shape(self) -> Tuple[int, ...]:

0 commit comments

Comments
 (0)