@@ -5392,7 +5392,6 @@ def func(input_val):
5392
5392
self .config .opset = current_opset
5393
5393
5394
5394
@check_tf_min_version ("1.14" )
5395
- #@skip_tflite("FlexRFFT2D")
5396
5395
def test_rfft_ops (self ):
5397
5396
5398
5397
def dft_slow (x , M , fft_length ):
@@ -5434,7 +5433,6 @@ def func3(x):
5434
5433
self ._run_test_case (func3 , [_OUTPUT ], {_INPUT : x_val })
5435
5434
5436
5435
@check_tf_min_version ("1.14" )
5437
- #@skip_tflite("FlexRFFT2D")
5438
5436
@skip_tfjs ("TFJS executes rfft with poor accuracy" )
5439
5437
@check_opset_min_version (10 , "Slice" )
5440
5438
def test_rfft_ops_fft_length (self ):
@@ -5446,7 +5444,6 @@ def func1_length(x):
5446
5444
self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5447
5445
5448
5446
@check_tf_min_version ("1.14" )
5449
- #@skip_tflite("FlexRFFT2D")
5450
5447
@skip_tfjs ("TFJS executes rfft with poor accuracy" )
5451
5448
@check_opset_min_version (10 , "Slice" )
5452
5449
def test_rfft_ops_fft_length_many (self ):
@@ -5461,7 +5458,6 @@ def func1_length(x):
5461
5458
self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5462
5459
5463
5460
@check_tf_min_version ("1.14" )
5464
- #@skip_tflite("FlexRFFT2D")
5465
5461
@check_opset_min_version (10 , "Slice" )
5466
5462
def test_rfft_ops_fft_length_many_bigger (self ):
5467
5463
for i in range (4 , 7 ):
@@ -5491,8 +5487,7 @@ def func1_length(x):
5491
5487
self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5492
5488
5493
5489
@check_tf_min_version ("1.14" )
5494
- #@skip_tflite("FlexRFFT2D")
5495
- @check_opset_min_version (10 , "Slice" )
5490
+ @check_opset_min_version (11 , "CumSum" )
5496
5491
def test_rfft2d_ops (self ):
5497
5492
5498
5493
x_val = make_xval ([3 , 4 ]).astype (np .float32 )
@@ -5516,34 +5511,35 @@ def func3(x):
5516
5511
self ._run_test_case (func3 , [_OUTPUT ], {_INPUT : x_val })
5517
5512
5518
5513
@check_tf_min_version ("1.14" )
5519
- #@skip_tflite("FlexRFFT2D")
5520
- @check_opset_min_version (10 , "Slice" )
5514
+ @check_opset_min_version (11 , "CumSum" )
5521
5515
def test_rfft2d_ops_fft_length (self ):
5522
5516
5523
5517
x_val = make_xval ([3 , 4 ]).astype (np .float32 )
5524
5518
def func1_length (x ):
5525
5519
op_ = tf .signal .rfft2d (x , np .array ([3 , 3 ], dtype = np .int32 ))
5526
5520
return tf .abs (op_ , name = _TFOUTPUT )
5527
- self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val }, optimize = False )
5528
- self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5521
+ with self .subTest (optimize = False ):
5522
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val }, optimize = False )
5523
+ with self .subTest (optimize = True ):
5524
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5529
5525
5530
5526
@check_tf_min_version ("1.14" )
5531
- #@skip_tflite("FlexRFFT2D")
5532
- @check_opset_min_version (10 , "Slice" )
5527
+ @check_opset_min_version (11 , "CumSum" )
5533
5528
def test_rfft2d_ops_fft_length_many (self ):
5534
5529
for i in range (7 , 4 , - 1 ):
5535
5530
for j in range (7 , 4 , - 1 ):
5536
5531
for m in range (0 , 3 ):
5537
5532
for n in range (0 , 3 ):
5538
- with self .subTest (shape = (i , j ), fft_length = (m , n )):
5539
- x_val = make_xval ([i , j ]).astype (np .float32 ) / 100
5540
- def func1_length (x ):
5541
- op_ = tf .signal .rfft2d (x , np .array ([i - m , j - n ], dtype = np .int32 ))
5542
- return tf .abs (op_ , name = _TFOUTPUT )
5543
- self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5533
+ for opt in [False , True ]:
5534
+ with self .subTest (shape = (i , j ), fft_length = (m , n ), optimize = opt ):
5535
+ x_val = make_xval ([i , j ]).astype (np .float32 ) / 100
5536
+ def func1_length (x ):
5537
+ op_ = tf .signal .rfft2d (x , np .array ([i - m , j - n ], dtype = np .int32 ))
5538
+ return tf .abs (op_ , name = _TFOUTPUT )
5539
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val }, optimize = opt )
5544
5540
5545
5541
@check_tf_min_version ("1.14" )
5546
- @check_opset_min_version (10 , "Slice " )
5542
+ @check_opset_min_version (11 , "CumSum " )
5547
5543
@unittest .skipIf (True , reason = "Not fully implemented for dynamic shape." )
5548
5544
def test_fft_ops (self ):
5549
5545
x_val = make_xval ([3 , 4 ]).astype (np .float32 )
@@ -5566,6 +5562,37 @@ def func(x):
5566
5562
x_val = np .array ([1 , 5 , 2 , 0 , 3 , 4 ], dtype = np .int64 )
5567
5563
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
5568
5564
5565
+ @check_tf_min_version ("1.14" )
5566
+ @check_opset_min_version (11 , "CumSum" )
5567
+ def test_rfft2d_ops_specific_dimension (self ):
5568
+
5569
+ x_val = make_xval ([3 , 1 , 4 ]).astype (np .float32 )
5570
+
5571
+ def func1 (x ):
5572
+ op_ = tf .signal .rfft2d (x , np .array ([1 , 4 ], dtype = np .int32 ))
5573
+ return tf .abs (op_ , name = _TFOUTPUT )
5574
+ with self .subTest (shape = (3 , 1 , 4 ), fft_length = (1 , 4 ), optimize = False ):
5575
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val }, optimize = False )
5576
+ with self .subTest (shape = (3 , 1 , 4 ), fft_length = (1 , 4 ), optimize = True ):
5577
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val })
5578
+
5579
+ for shape in [(3 , 1 , 4 ), (5 , 7 ), (3 , 5 , 7 ), (7 , 5 )]:
5580
+ for fft_length in [shape [- 2 :], (1 , shape [- 1 ]),
5581
+ (min (2 , shape [- 2 ]), shape [- 1 ]),
5582
+ (shape [- 2 ], 2 ),
5583
+ (min (3 , shape [- 2 ]), min (4 , shape [- 2 ]))]:
5584
+ if fft_length == (1 , 1 ):
5585
+ # The code fails in this case but that's unlikely to happen.
5586
+ continue
5587
+ for optimize in [False , True ]:
5588
+ with self .subTest (shape = shape , fft_length = fft_length , optimize = optimize ):
5589
+ x_val = make_xval (list (shape )).astype (np .float32 )
5590
+ x_val /= x_val .size
5591
+ def func1 (x ):
5592
+ op_ = tf .signal .rfft2d (x , np .array (fft_length , dtype = np .int32 ))
5593
+ return tf .abs (op_ , name = _TFOUTPUT )
5594
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val }, optimize = optimize )
5595
+
5569
5596
@check_tf_min_version ("2.1" )
5570
5597
@skip_tflite ("TFlite errors on some attributes" )
5571
5598
@check_opset_min_version (9 , "string" )
0 commit comments