@@ -80,14 +80,9 @@ def test_stft_initializer(self):
80
80
shape = (256 , 1 , 513 )
81
81
time_range = np .arange (256 ).reshape ((- 1 , 1 , 1 ))
82
82
freq_range = (np .arange (513 ) / 1024.0 ).reshape ((1 , 1 , - 1 ))
83
- pi = np .arccos (np .float64 (- 1 ))
83
+ pi = np .arccos (np .float32 (- 1 ))
84
84
args = - 2 * pi * time_range * freq_range
85
-
86
- tol_kwargs = {}
87
- if backend .backend () == "jax" :
88
- # TODO(mostafa-mahmoud): investigate the cases
89
- # of non-small error in jax and torch
90
- tol_kwargs = {"atol" : 1e-4 , "rtol" : 1e-6 }
85
+ tol_kwargs = {"atol" : 1e-4 , "rtol" : 1e-6 }
91
86
92
87
initializer = initializers .STFT ("real" , None )
93
88
values = backend .convert_to_numpy (initializer (shape ))
@@ -101,8 +96,8 @@ def test_stft_initializer(self):
101
96
True ,
102
97
)
103
98
window = scipy .signal .windows .get_window ("hamming" , 256 , True )
104
- window = window .astype ("float64 " ).reshape ((- 1 , 1 , 1 ))
105
- values = backend .convert_to_numpy (initializer (shape , "float64 " ))
99
+ window = window .astype ("float32 " ).reshape ((- 1 , 1 , 1 ))
100
+ values = backend .convert_to_numpy (initializer (shape , "float32 " ))
106
101
self .assertAllClose (np .cos (args ) * window , values , ** tol_kwargs )
107
102
self .run_class_serialization_test (initializer )
108
103
@@ -113,9 +108,9 @@ def test_stft_initializer(self):
113
108
False ,
114
109
)
115
110
window = scipy .signal .windows .get_window ("tukey" , 256 , False )
116
- window = window .astype ("float64 " ).reshape ((- 1 , 1 , 1 ))
111
+ window = window .astype ("float32 " ).reshape ((- 1 , 1 , 1 ))
117
112
window = window / np .sqrt (np .sum (window ** 2 ))
118
- values = backend .convert_to_numpy (initializer (shape , "float64 " ))
113
+ values = backend .convert_to_numpy (initializer (shape , "float32 " ))
119
114
self .assertAllClose (np .sin (args ) * window , values , ** tol_kwargs )
120
115
self .run_class_serialization_test (initializer )
121
116
@@ -125,9 +120,9 @@ def test_stft_initializer(self):
125
120
"spectrum" ,
126
121
)
127
122
window = np .arange (1 , 257 )
128
- window = window .astype ("float64 " ).reshape ((- 1 , 1 , 1 ))
123
+ window = window .astype ("float32 " ).reshape ((- 1 , 1 , 1 ))
129
124
window = window / np .sum (window )
130
- values = backend .convert_to_numpy (initializer (shape , "float64 " ))
125
+ values = backend .convert_to_numpy (initializer (shape , "float32 " ))
131
126
self .assertAllClose (np .sin (args ) * window , values , ** tol_kwargs )
132
127
self .run_class_serialization_test (initializer )
133
128
0 commit comments