@@ -5493,8 +5493,10 @@ def test_rfft2d_ops_fft_length(self):
5493
5493
def func1_length (x ):
5494
5494
op_ = tf .signal .rfft2d (x , np .array ([3 , 3 ], dtype = np .int32 ))
5495
5495
return tf .abs (op_ , name = _TFOUTPUT )
5496
- self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val }, optimize = False )
5497
- self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5496
+ with self .subTest (optimize = False ):
5497
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val }, optimize = False )
5498
+ with self .subTest (optimize = True ):
5499
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5498
5500
5499
5501
@check_tf_min_version ("1.14" )
5500
5502
@check_opset_min_version (10 , "Slice" )
@@ -5503,12 +5505,13 @@ def test_rfft2d_ops_fft_length_many(self):
5503
5505
for j in range (7 , 4 , - 1 ):
5504
5506
for m in range (0 , 3 ):
5505
5507
for n in range (0 , 3 ):
5506
- with self .subTest (shape = (i , j ), fft_length = (m , n )):
5507
- x_val = make_xval ([i , j ]).astype (np .float32 ) / 100
5508
- def func1_length (x ):
5509
- op_ = tf .signal .rfft2d (x , np .array ([i - m , j - n ], dtype = np .int32 ))
5510
- return tf .abs (op_ , name = _TFOUTPUT )
5511
- self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5508
+ for opt in [False , True ]:
5509
+ with self .subTest (shape = (i , j ), fft_length = (m , n ), optimize = opt ):
5510
+ x_val = make_xval ([i , j ]).astype (np .float32 ) / 100
5511
+ def func1_length (x ):
5512
+ op_ = tf .signal .rfft2d (x , np .array ([i - m , j - n ], dtype = np .int32 ))
5513
+ return tf .abs (op_ , name = _TFOUTPUT )
5514
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val }, optimize = opt )
5512
5515
5513
5516
@check_tf_min_version ("1.14" )
5514
5517
@check_opset_min_version (10 , "Slice" )
@@ -5543,21 +5546,23 @@ def test_rfft2d_ops_specific_dimension(self):
5543
5546
def func1 (x ):
5544
5547
op_ = tf .signal .rfft2d (x , np .array ([1 , 4 ], dtype = np .int32 ))
5545
5548
return tf .abs (op_ , name = _TFOUTPUT )
5546
- self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val }, optimize = False )
5547
- self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val })
5549
+ with self .subTest (shape = (3 , 1 , 4 ), fft_length = (1 , 4 ), optimize = False ):
5550
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val }, optimize = False )
5551
+ with self .subTest (shape = (3 , 1 , 4 ), fft_length = (1 , 4 ), optimize = True ):
5552
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val })
5548
5553
5549
5554
for shape in [(3 , 1 , 4 ), (5 , 7 ), (3 , 5 , 7 ), (7 , 5 )]:
5550
5555
for fft_length in [shape [- 2 :], (1 , shape [- 1 ]),
5551
5556
(min (2 , shape [- 2 ]), shape [- 1 ]),
5552
5557
(shape [- 2 ], 2 ),
5553
5558
(min (3 , shape [- 2 ]), min (4 , shape [- 2 ]))]:
5554
- with self . subTests ( shape = shape , fft_length = fft_length ) :
5555
- x_val = make_xval ( list ( shape )). astype ( np . float32 )
5556
- def func1 ( x ):
5557
- op_ = tf . signal . rfft2d ( x , np . array ( fft_length , dtype = np . int32 ))
5558
- return tf .abs ( op_ , name = _TFOUTPUT )
5559
- self . _run_test_case ( func1 , [ _OUTPUT ], { _INPUT : x_val }, optimize = False )
5560
- self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val })
5559
+ for optimize in [ False , True ] :
5560
+ with self . subTest ( shape = shape , fft_length = fft_length , optimize = optimize ):
5561
+ x_val = make_xval ( list ( shape )). astype ( np . float32 )
5562
+ def func1 ( x ):
5563
+ op_ = tf .signal . rfft2d ( x , np . array ( fft_length , dtype = np . int32 ) )
5564
+ return tf . abs ( op_ , name = _TFOUTPUT )
5565
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val }, optimize = optimize )
5561
5566
5562
5567
@check_tf_min_version ("2.1" )
5563
5568
@skip_tflite ("TFlite errors on some attributes" )
0 commit comments