@@ -5367,7 +5367,6 @@ def func(input_val):
5367
5367
self .config .opset = current_opset
5368
5368
5369
5369
@check_tf_min_version ("1.14" )
5370
- #@skip_tflite("FlexRFFT2D")
5371
5370
def test_rfft_ops (self ):
5372
5371
5373
5372
def dft_slow (x , M , fft_length ):
@@ -5409,7 +5408,6 @@ def func3(x):
5409
5408
self ._run_test_case (func3 , [_OUTPUT ], {_INPUT : x_val })
5410
5409
5411
5410
@check_tf_min_version ("1.14" )
5412
- #@skip_tflite("FlexRFFT2D")
5413
5411
@skip_tfjs ("TFJS executes rfft with poor accuracy" )
5414
5412
@check_opset_min_version (10 , "Slice" )
5415
5413
def test_rfft_ops_fft_length (self ):
@@ -5421,7 +5419,6 @@ def func1_length(x):
5421
5419
self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5422
5420
5423
5421
@check_tf_min_version ("1.14" )
5424
- #@skip_tflite("FlexRFFT2D")
5425
5422
@skip_tfjs ("TFJS executes rfft with poor accuracy" )
5426
5423
@check_opset_min_version (10 , "Slice" )
5427
5424
def test_rfft_ops_fft_length_many (self ):
@@ -5436,7 +5433,6 @@ def func1_length(x):
5436
5433
self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5437
5434
5438
5435
@check_tf_min_version ("1.14" )
5439
- #@skip_tflite("FlexRFFT2D")
5440
5436
@check_opset_min_version (10 , "Slice" )
5441
5437
def test_rfft_ops_fft_length_many_bigger (self ):
5442
5438
for i in range (4 , 7 ):
@@ -5466,7 +5462,6 @@ def func1_length(x):
5466
5462
self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5467
5463
5468
5464
@check_tf_min_version ("1.14" )
5469
- #@skip_tflite("FlexRFFT2D")
5470
5465
@check_opset_min_version (10 , "Slice" )
5471
5466
def test_rfft2d_ops (self ):
5472
5467
@@ -5491,7 +5486,6 @@ def func3(x):
5491
5486
self ._run_test_case (func3 , [_OUTPUT ], {_INPUT : x_val })
5492
5487
5493
5488
@check_tf_min_version ("1.14" )
5494
- #@skip_tflite("FlexRFFT2D")
5495
5489
@check_opset_min_version (10 , "Slice" )
5496
5490
def test_rfft2d_ops_fft_length (self ):
5497
5491
@@ -5503,7 +5497,6 @@ def func1_length(x):
5503
5497
self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5504
5498
5505
5499
@check_tf_min_version ("1.14" )
5506
- #@skip_tflite("FlexRFFT2D")
5507
5500
@check_opset_min_version (10 , "Slice" )
5508
5501
def test_rfft2d_ops_fft_length_many (self ):
5509
5502
for i in range (7 , 4 , - 1 ):
@@ -5541,7 +5534,34 @@ def func(x):
5541
5534
x_val = np .array ([1 , 5 , 2 , 0 , 3 , 4 ], dtype = np .int64 )
5542
5535
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
5543
5536
5537
+ @check_tf_min_version ("1.14" )
5538
+ @check_opset_min_version (10 , "Slice" )
5539
+ def test_rfft2d_ops_specific_dimension (self ):
5540
+
5541
+ x_val = make_xval ([3 , 1 , 4 ]).astype (np .float32 )
5542
+
5543
+ def func1 (x ):
5544
+ op_ = tf .signal .rfft2d (x , np .array ([1 , 4 ], dtype = np .int32 ))
5545
+ return tf .abs (op_ , name = _TFOUTPUT )
5546
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val }, optimize = False )
5547
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val })
5548
+
5549
+ for shape in [(3 , 1 , 4 ), (5 , 7 ), (3 , 5 , 7 ), (7 , 5 )]:
5550
+ for fft_length in [shape [- 2 :], (1 , shape [- 1 ]),
5551
+ (min (2 , shape [- 2 ]), shape [- 1 ]),
5552
+ (shape [- 2 ], 2 ),
5553
+ (min (3 , shape [- 2 ]), min (4 , shape [- 2 ]))]:
5554
+ with self .subTests (shape = shape , fft_length = fft_length ):
5555
+ x_val = make_xval (list (shape )).astype (np .float32 )
5556
+ def func1 (x ):
5557
+ op_ = tf .signal .rfft2d (x , np .array (fft_length , dtype = np .int32 ))
5558
+ return tf .abs (op_ , name = _TFOUTPUT )
5559
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val }, optimize = False )
5560
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val })
5544
5561
5545
5562
5546
5563
if __name__ == '__main__' :
5564
+ cl = BackendTests ()
5565
+ cl .setUp ()
5566
+ cl .test_rfft2d_ops_specific_dimension ()
5547
5567
unittest_main ()
0 commit comments