@@ -5765,11 +5765,8 @@ def test_add_python_types(self, dtype):
5765
5765
5766
5766
@parameterized .named_parameters (named_product (dtype = ALL_DTYPES ))
5767
5767
def test_bartlett (self , dtype ):
5768
- import jax .numpy as jnp
5769
-
5770
5768
x = knp .ones ((), dtype = dtype )
5771
- x_jax = jnp .ones ((), dtype = dtype )
5772
- expected_dtype = standardize_dtype (jnp .bartlett (x_jax ).dtype )
5769
+ expected_dtype = backend .floatx ()
5773
5770
5774
5771
self .assertEqual (
5775
5772
standardize_dtype (knp .bartlett (x ).dtype ), expected_dtype
@@ -5781,11 +5778,8 @@ def test_bartlett(self, dtype):
5781
5778
5782
5779
@parameterized .named_parameters (named_product (dtype = ALL_DTYPES ))
5783
5780
def test_blackman (self , dtype ):
5784
- import jax .numpy as jnp
5785
-
5786
5781
x = knp .ones ((), dtype = dtype )
5787
- x_jax = jnp .ones ((), dtype = dtype )
5788
- expected_dtype = standardize_dtype (jnp .blackman (x_jax ).dtype )
5782
+ expected_dtype = backend .floatx ()
5789
5783
5790
5784
self .assertEqual (
5791
5785
standardize_dtype (knp .blackman (x ).dtype ), expected_dtype
@@ -5797,11 +5791,8 @@ def test_blackman(self, dtype):
5797
5791
5798
5792
@parameterized .named_parameters (named_product (dtype = ALL_DTYPES ))
5799
5793
def test_hamming (self , dtype ):
5800
- import jax .numpy as jnp
5801
-
5802
5794
x = knp .ones ((), dtype = dtype )
5803
- x_jax = jnp .ones ((), dtype = dtype )
5804
- expected_dtype = standardize_dtype (jnp .hamming (x_jax ).dtype )
5795
+ expected_dtype = backend .floatx ()
5805
5796
5806
5797
self .assertEqual (
5807
5798
standardize_dtype (knp .hamming (x ).dtype ), expected_dtype
@@ -5813,11 +5804,8 @@ def test_hamming(self, dtype):
5813
5804
5814
5805
@parameterized .named_parameters (named_product (dtype = ALL_DTYPES ))
5815
5806
def test_hanning (self , dtype ):
5816
- import jax .numpy as jnp
5817
-
5818
5807
x = knp .ones ((), dtype = dtype )
5819
- x_jax = jnp .ones ((), dtype = dtype )
5820
- expected_dtype = standardize_dtype (jnp .hanning (x_jax ).dtype )
5808
+ expected_dtype = backend .floatx ()
5821
5809
5822
5810
self .assertEqual (
5823
5811
standardize_dtype (knp .hanning (x ).dtype ), expected_dtype
@@ -5829,13 +5817,9 @@ def test_hanning(self, dtype):
5829
5817
5830
5818
@parameterized .named_parameters (named_product (dtype = ALL_DTYPES ))
5831
5819
def test_kaiser (self , dtype ):
5832
- import jax .numpy as jnp
5833
-
5834
5820
x = knp .ones ((), dtype = dtype )
5835
5821
beta = knp .ones ((), dtype = dtype )
5836
- x_jax = jnp .ones ((), dtype = dtype )
5837
- beta_jax = jnp .ones ((), dtype = dtype )
5838
- expected_dtype = standardize_dtype (jnp .kaiser (x_jax , beta_jax ).dtype )
5822
+ expected_dtype = backend .floatx ()
5839
5823
5840
5824
self .assertEqual (
5841
5825
standardize_dtype (knp .kaiser (x , beta ).dtype ), expected_dtype
@@ -7268,13 +7252,17 @@ def test_eye(self, dtype):
7268
7252
import jax .numpy as jnp
7269
7253
7270
7254
expected_dtype = standardize_dtype (jnp .eye (3 , dtype = dtype ).dtype )
7255
+ if dtype is None :
7256
+ expected_dtype = backend .floatx ()
7271
7257
7272
7258
self .assertEqual (
7273
7259
standardize_dtype (knp .eye (3 , dtype = dtype ).dtype ),
7274
7260
expected_dtype ,
7275
7261
)
7276
7262
7277
7263
expected_dtype = standardize_dtype (jnp .eye (3 , 4 , 1 , dtype = dtype ).dtype )
7264
+ if dtype is None :
7265
+ expected_dtype = backend .floatx ()
7278
7266
7279
7267
self .assertEqual (
7280
7268
standardize_dtype (knp .eye (3 , 4 , k = 1 , dtype = dtype ).dtype ),
@@ -7506,6 +7494,8 @@ def test_identity(self, dtype):
7506
7494
import jax .numpy as jnp
7507
7495
7508
7496
expected_dtype = standardize_dtype (jnp .identity (3 , dtype = dtype ).dtype )
7497
+ if dtype is None :
7498
+ expected_dtype = backend .floatx ()
7509
7499
7510
7500
self .assertEqual (
7511
7501
standardize_dtype (knp .identity (3 , dtype = dtype ).dtype ),
@@ -9141,7 +9131,7 @@ def test_angle(self, dtype):
9141
9131
x = knp .ones ((1 ,), dtype = dtype )
9142
9132
x_jax = jnp .ones ((1 ,), dtype = dtype )
9143
9133
expected_dtype = standardize_dtype (jnp .angle (x_jax ).dtype )
9144
- if dtype == "int64" :
9134
+ if dtype == "bool" or is_int_dtype ( dtype ) :
9145
9135
expected_dtype = backend .floatx ()
9146
9136
9147
9137
self .assertEqual (standardize_dtype (knp .angle (x ).dtype ), expected_dtype )
0 commit comments