|
14 | 14 | # ============================================================================== |
15 | 15 | """Quantized distribution.""" |
16 | 16 |
|
17 | | -from typing import cast, Optional, Tuple |
| 17 | +from typing import Optional, Tuple |
18 | 18 |
|
19 | 19 | import chex |
20 | 20 | from distrax._src.distributions import distribution as base_distribution |
@@ -61,8 +61,7 @@ def __init__(self, |
61 | 61 | `distribution` and must not result in additional batch dimensions after |
62 | 62 | broadcasting. |
63 | 63 | """ |
64 | | - self._dist: base_distribution.Distribution[Array, Tuple[ |
65 | | - int, ...], jnp.dtype] = conversion.as_distribution(distribution) |
| 64 | + self._dist = conversion.as_distribution(distribution) |
66 | 65 | if self._dist.event_shape: |
67 | 66 | raise ValueError(f'The base distribution must be univariate, but its ' |
68 | 67 | f'`event_shape` is {self._dist.event_shape}.') |
@@ -107,9 +106,7 @@ def high(self) -> Optional[Array]: |
107 | 106 | @property |
108 | 107 | def event_shape(self) -> Tuple[int, ...]: |
109 | 108 | """Shape of event of distribution samples.""" |
110 | | - event_shape = self.distribution.event_shape |
111 | | - # TODO(b/149413467): Remove explicit casting when resolved. |
112 | | - return cast(Tuple[int, ...], event_shape) |
| 109 | + return self.distribution.event_shape |
113 | 110 |
|
114 | 111 | @property |
115 | 112 | def batch_shape(self) -> Tuple[int, ...]: |
|
0 commit comments