@@ -80,14 +80,9 @@ def test_stft_initializer(self):
8080 shape = (256 , 1 , 513 )
8181 time_range = np .arange (256 ).reshape ((- 1 , 1 , 1 ))
8282 freq_range = (np .arange (513 ) / 1024.0 ).reshape ((1 , 1 , - 1 ))
83- pi = np .arccos (np .float64 (- 1 ))
83+ pi = np .arccos (np .float32 (- 1 ))
8484 args = - 2 * pi * time_range * freq_range
85-
86- tol_kwargs = {}
87- if backend .backend () == "jax" :
88- # TODO(mostafa-mahmoud): investigate the cases
89- # of non-small error in jax and torch
90- tol_kwargs = {"atol" : 1e-4 , "rtol" : 1e-6 }
85+ tol_kwargs = {"atol" : 1e-4 , "rtol" : 1e-6 }
9186
9287 initializer = initializers .STFT ("real" , None )
9388 values = backend .convert_to_numpy (initializer (shape ))
@@ -101,8 +96,8 @@ def test_stft_initializer(self):
10196 True ,
10297 )
10398 window = scipy .signal .windows .get_window ("hamming" , 256 , True )
104- window = window .astype ("float64 " ).reshape ((- 1 , 1 , 1 ))
105- values = backend .convert_to_numpy (initializer (shape , "float64 " ))
99+ window = window .astype ("float32 " ).reshape ((- 1 , 1 , 1 ))
100+ values = backend .convert_to_numpy (initializer (shape , "float32 " ))
106101 self .assertAllClose (np .cos (args ) * window , values , ** tol_kwargs )
107102 self .run_class_serialization_test (initializer )
108103
@@ -113,9 +108,9 @@ def test_stft_initializer(self):
113108 False ,
114109 )
115110 window = scipy .signal .windows .get_window ("tukey" , 256 , False )
116- window = window .astype ("float64 " ).reshape ((- 1 , 1 , 1 ))
111+ window = window .astype ("float32 " ).reshape ((- 1 , 1 , 1 ))
117112 window = window / np .sqrt (np .sum (window ** 2 ))
118- values = backend .convert_to_numpy (initializer (shape , "float64 " ))
113+ values = backend .convert_to_numpy (initializer (shape , "float32 " ))
119114 self .assertAllClose (np .sin (args ) * window , values , ** tol_kwargs )
120115 self .run_class_serialization_test (initializer )
121116
@@ -125,9 +120,9 @@ def test_stft_initializer(self):
125120 "spectrum" ,
126121 )
127122 window = np .arange (1 , 257 )
128- window = window .astype ("float64 " ).reshape ((- 1 , 1 , 1 ))
123+ window = window .astype ("float32 " ).reshape ((- 1 , 1 , 1 ))
129124 window = window / np .sum (window )
130- values = backend .convert_to_numpy (initializer (shape , "float64 " ))
125+ values = backend .convert_to_numpy (initializer (shape , "float32 " ))
131126 self .assertAllClose (np .sin (args ) * window , values , ** tol_kwargs )
132127 self .run_class_serialization_test (initializer )
133128
0 commit comments