Skip to content

Commit d3d039f

Browse files
tomwardioDistraxDev
authored andcommitted
Update chex.assert_type to check concrete types instead of just asserting that the type is a floating/integer sub-type.
Previously, `assert_type` would only check that the input was of the same parent type. For example: ``` x = np.ones((1,), dtype=np.float32) chex.assert_type(x, np.float64) # Succeeds chex.assert_type(x, np.int32) # Fails. ``` Instead, if a concrete dtype is provided we check that the input has the same type. If `float` or `np.floating` is provided, we continue to only assert that the input is the same parent. ``` x = np.ones((1,), dtype=np.float32) chex.assert_type(x, np.float64) # Fails chex.assert_type(x, float) # Succeeds. ``` PiperOrigin-RevId: 607102995
1 parent 7c0e1bf commit d3d039f

15 files changed

+114
-82
lines changed

distrax/_src/distributions/deterministic_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import chex
2121
from distrax._src.distributions import deterministic
2222
from distrax._src.utils import equivalence
23+
import jax.experimental
2324
import jax.numpy as jnp
2425
import numpy as np
2526

@@ -107,10 +108,11 @@ def test_sample_shape(self, loc, sample_shape):
107108
('float32', jnp.float32),
108109
('float64', jnp.float64))
109110
def test_sample_dtype(self, dtype):
110-
dist = self.distrax_cls(loc=jnp.zeros((), dtype=dtype))
111-
samples = self.variant(dist.sample)(seed=self.key)
112-
self.assertEqual(samples.dtype, dist.dtype)
113-
chex.assert_type(samples, dtype)
111+
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
112+
dist = self.distrax_cls(loc=jnp.zeros((), dtype=dtype))
113+
samples = self.variant(dist.sample)(seed=self.key)
114+
self.assertEqual(samples.dtype, dist.dtype)
115+
chex.assert_type(samples, dtype)
114116

115117
@chex.all_variants
116118
@parameterized.named_parameters(

distrax/_src/distributions/epsilon_greedy_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import chex
2323
from distrax._src.distributions import epsilon_greedy
2424
from distrax._src.utils import equivalence
25+
import jax.experimental
2526
import jax.numpy as jnp
2627
import numpy as np
2728

@@ -51,11 +52,12 @@ def test_num_categories(self):
5152
('float32', jnp.float32),
5253
('float64', jnp.float64))
5354
def test_sample_dtype(self, dtype):
54-
dist = self.distrax_cls(
55-
preferences=self.preferences, epsilon=self.epsilon, dtype=dtype)
56-
samples = self.variant(dist.sample)(seed=self.key)
57-
self.assertEqual(samples.dtype, dist.dtype)
58-
chex.assert_type(samples, dtype)
55+
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
56+
dist = self.distrax_cls(
57+
preferences=self.preferences, epsilon=self.epsilon, dtype=dtype)
58+
samples = self.variant(dist.sample)(seed=self.key)
59+
self.assertEqual(samples.dtype, dist.dtype)
60+
chex.assert_type(samples, dtype)
5961

6062
def test_jittable(self):
6163
super()._test_jittable(

distrax/_src/distributions/gamma_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import chex
2121
from distrax._src.distributions import gamma
2222
from distrax._src.utils import equivalence
23+
import jax.experimental
2324
import jax.numpy as jnp
2425
import numpy as np
2526

@@ -73,11 +74,12 @@ def test_sample_shape(self, distr_params, sample_shape):
7374
('float32', jnp.float32),
7475
('float64', jnp.float64))
7576
def test_sample_dtype(self, dtype):
76-
dist = self.distrax_cls(
77-
concentration=jnp.ones((), dtype), rate=jnp.ones((), dtype))
78-
samples = self.variant(dist.sample)(seed=self.key)
79-
self.assertEqual(samples.dtype, dist.dtype)
80-
chex.assert_type(samples, dtype)
77+
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
78+
dist = self.distrax_cls(
79+
concentration=jnp.ones((), dtype), rate=jnp.ones((), dtype))
80+
samples = self.variant(dist.sample)(seed=self.key)
81+
self.assertEqual(samples.dtype, dist.dtype)
82+
chex.assert_type(samples, dtype)
8183

8284
@chex.all_variants
8385
@parameterized.named_parameters(

distrax/_src/distributions/greedy_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import chex
2121
from distrax._src.distributions import greedy
2222
from distrax._src.utils import equivalence
23+
import jax.experimental
2324
import jax.numpy as jnp
2425
import numpy as np
2526

@@ -48,10 +49,11 @@ def test_num_categories(self):
4849
('float32', jnp.float32),
4950
('float64', jnp.float64))
5051
def test_sample_dtype(self, dtype):
51-
dist = self.distrax_cls(preferences=self.preferences, dtype=dtype)
52-
samples = self.variant(dist.sample)(seed=self.key)
53-
self.assertEqual(samples.dtype, dist.dtype)
54-
chex.assert_type(samples, dtype)
52+
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
53+
dist = self.distrax_cls(preferences=self.preferences, dtype=dtype)
54+
samples = self.variant(dist.sample)(seed=self.key)
55+
self.assertEqual(samples.dtype, dist.dtype)
56+
chex.assert_type(samples, dtype)
5557

5658
def test_jittable(self):
5759
super()._test_jittable((np.array([0., 4., -1., 4.]),))

distrax/_src/distributions/gumbel_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import chex
2121
from distrax._src.distributions import gumbel
2222
from distrax._src.utils import equivalence
23+
import jax.experimental
2324
import jax.numpy as jnp
2425
import numpy as np
2526

@@ -67,11 +68,12 @@ def test_sample_shape(self, distr_params, sample_shape):
6768
('float32', jnp.float32),
6869
('float64', jnp.float64))
6970
def test_sample_dtype(self, dtype):
70-
dist = self.distrax_cls(
71-
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
72-
samples = self.variant(dist.sample)(seed=self.key)
73-
self.assertEqual(samples.dtype, dist.dtype)
74-
chex.assert_type(samples, dtype)
71+
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
72+
dist = self.distrax_cls(
73+
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
74+
samples = self.variant(dist.sample)(seed=self.key)
75+
self.assertEqual(samples.dtype, dist.dtype)
76+
chex.assert_type(samples, dtype)
7577

7678
@chex.all_variants
7779
@parameterized.named_parameters(

distrax/_src/distributions/laplace_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import chex
2121
from distrax._src.distributions import laplace
2222
from distrax._src.utils import equivalence
23+
import jax.experimental
2324
import jax.numpy as jnp
2425
import numpy as np
2526

@@ -65,11 +66,12 @@ def test_sample_shape(self, distr_params, sample_shape):
6566
('float32', jnp.float32),
6667
('float64', jnp.float64))
6768
def test_sample_dtype(self, dtype):
68-
dist = self.distrax_cls(
69-
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
70-
samples = self.variant(dist.sample)(seed=self.key)
71-
self.assertEqual(samples.dtype, dist.dtype)
72-
chex.assert_type(samples, dtype)
69+
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
70+
dist = self.distrax_cls(
71+
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
72+
samples = self.variant(dist.sample)(seed=self.key)
73+
self.assertEqual(samples.dtype, dist.dtype)
74+
chex.assert_type(samples, dtype)
7375

7476
@chex.all_variants
7577
@parameterized.named_parameters(

distrax/_src/distributions/log_stddev_normal_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from distrax._src.distributions import log_stddev_normal as lsn
2222
from distrax._src.distributions import normal
2323
import jax
24+
import jax.experimental
2425
import jax.numpy as jnp
2526
import mock
2627
import numpy as np
@@ -105,11 +106,12 @@ def test_sampling_batched_custom_dim(self):
105106
('float32', jnp.float32),
106107
('float64', jnp.float64))
107108
def test_sample_dtype(self, dtype):
108-
dist = lsn.LogStddevNormal(
109-
loc=jnp.zeros((), dtype), log_scale=jnp.zeros((), dtype))
110-
samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
111-
self.assertEqual(samples.dtype, dist.dtype)
112-
chex.assert_type(samples, dtype)
109+
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
110+
dist = lsn.LogStddevNormal(
111+
loc=jnp.zeros((), dtype), log_scale=jnp.zeros((), dtype))
112+
samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
113+
self.assertEqual(samples.dtype, dist.dtype)
114+
chex.assert_type(samples, dtype)
113115

114116
def test_kl_versus_normal(self):
115117
loc, scale = jnp.array([2.0]), jnp.array([2.0])

distrax/_src/distributions/logistic_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import chex
2121
from distrax._src.distributions import logistic
2222
from distrax._src.utils import equivalence
23+
import jax.experimental
2324
import jax.numpy as jnp
2425
import numpy as np
2526

@@ -66,11 +67,12 @@ def test_sample_shape(self, distr_params, sample_shape):
6667
('float32', jnp.float32),
6768
('float64', jnp.float64))
6869
def test_sample_dtype(self, dtype):
69-
dist = self.distrax_cls(
70-
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
71-
samples = self.variant(dist.sample)(seed=self.key)
72-
self.assertEqual(samples.dtype, dist.dtype)
73-
chex.assert_type(samples, dtype)
70+
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
71+
dist = self.distrax_cls(
72+
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
73+
samples = self.variant(dist.sample)(seed=self.key)
74+
self.assertEqual(samples.dtype, dist.dtype)
75+
chex.assert_type(samples, dtype)
7476

7577
@chex.all_variants
7678
@parameterized.named_parameters(

distrax/_src/distributions/multinomial_test.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from distrax._src.utils import equivalence
2323
from distrax._src.utils import math
2424
import jax
25+
import jax.experimental
2526
import jax.numpy as jnp
2627
import numpy as np
2728
from scipy import stats
@@ -405,12 +406,16 @@ def test_sample_and_log_prob(self, dist_params, sample_shape):
405406
('float32', jnp.float32),
406407
('float64', jnp.float64))
407408
def test_sample_dtype(self, dtype):
408-
dist_params = {
409-
'logits': self.logits, 'dtype': dtype, 'total_count': self.total_count}
410-
dist = self.distrax_cls(**dist_params)
411-
samples = self.variant(dist.sample)(seed=self.key)
412-
self.assertEqual(samples.dtype, dist.dtype)
413-
chex.assert_type(samples, dtype)
409+
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
410+
dist_params = {
411+
'logits': self.logits,
412+
'dtype': dtype,
413+
'total_count': self.total_count,
414+
}
415+
dist = self.distrax_cls(**dist_params)
416+
samples = self.variant(dist.sample)(seed=self.key)
417+
self.assertEqual(samples.dtype, dist.dtype)
418+
chex.assert_type(samples, dtype)
414419

415420
@chex.all_variants
416421
def test_sample_extreme_probs(self):

distrax/_src/distributions/mvn_diag_plus_low_rank_test.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from distrax._src.utils import equivalence
2323

2424
import jax
25+
import jax.experimental
2526
import jax.numpy as jnp
2627
import numpy as np
2728
from tensorflow_probability.substrates import jax as tfp
@@ -180,13 +181,14 @@ def test_sample_shape(self, sample_shape, loc_shape, scale_diag_shape,
180181
('float32', jnp.float32),
181182
('float64', jnp.float64))
182183
def test_sample_dtype(self, dtype):
183-
dist_params = {
184-
'loc': np.array([0., 0.], dtype),
185-
'scale_diag': np.array([1., 1.], dtype)}
186-
dist = MultivariateNormalDiagPlusLowRank(**dist_params)
187-
samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
188-
self.assertEqual(samples.dtype, dist.dtype)
189-
chex.assert_type(samples, dtype)
184+
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
185+
dist_params = {
186+
'loc': np.array([0., 0.], dtype),
187+
'scale_diag': np.array([1., 1.], dtype)}
188+
dist = MultivariateNormalDiagPlusLowRank(**dist_params)
189+
samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
190+
self.assertEqual(samples.dtype, dist.dtype)
191+
chex.assert_type(samples, dtype)
190192

191193
@chex.all_variants
192194
@parameterized.named_parameters(

0 commit comments

Comments
 (0)