@@ -5367,21 +5367,29 @@ def func(input_val):
5367
5367
self .config .opset = current_opset
5368
5368
5369
5369
@check_tf_min_version ("1.14" )
5370
- @skip_tflite ("FlexRFFT2D" )
5370
+ # @skip_tflite("FlexRFFT2D")
5371
5371
def test_rfft_ops (self ):
5372
5372
5373
- def dft_slow (x , M ):
5374
- xt = x .T
5375
- res = np .dot (M , xt )
5373
+ def dft_slow (x , M , fft_length ):
5374
+ xt = x [:, :fft_length ].T
5375
+ size = fft_length // 2 + 1
5376
+ res = np .dot (M [:, :, :fft_length ], xt )[:, :size , :]
5376
5377
return np .transpose (res , (0 , 2 , 1 ))
5377
5378
5378
5379
x_val = make_xval ([2 , 4 ]).astype (np .float32 )
5379
5380
M_both = make_dft_constant (x_val .shape [1 ], x_val .dtype , x_val .shape [1 ])
5380
- fft = dft_slow (x_val , M_both )
5381
+ fft = dft_slow (x_val , M_both , x_val . shape [ 1 ] )
5381
5382
fft_npy = np .fft .rfft (x_val )
5382
5383
assert_almost_equal (fft [0 , :, :], np .real (fft_npy ))
5383
5384
assert_almost_equal (fft [1 , :, :], np .imag (fft_npy ))
5384
5385
5386
+ x_val = make_xval ([2 , 4 ]).astype (np .float32 )
5387
+ M_both = make_dft_constant (x_val .shape [1 ], x_val .dtype , x_val .shape [1 ]- 1 )
5388
+ fft = dft_slow (x_val , M_both , x_val .shape [1 ]- 1 )
5389
+ fft_npy = np .fft .rfft (x_val , x_val .shape [1 ]- 1 )
5390
+ assert_almost_equal (fft [0 , :, :], np .real (fft_npy ))
5391
+ assert_almost_equal (fft [1 , :, :], np .imag (fft_npy ))
5392
+
5385
5393
x_val = make_xval ([3 , 4 ]).astype (np .float32 )
5386
5394
def func1 (x ):
5387
5395
op_ = tf .signal .rfft (x )
@@ -5401,7 +5409,117 @@ def func3(x):
5401
5409
self ._run_test_case (func3 , [_OUTPUT ], {_INPUT : x_val })
5402
5410
5403
5411
@check_tf_min_version ("1.14" )
5404
- @check_opset_min_version (11 , "range" )
5412
+ #@skip_tflite("FlexRFFT2D")
5413
+ @skip_tfjs ("TFJS executes rfft with poor accuracy" )
5414
+ @check_opset_min_version (10 , "Slice" )
5415
+ def test_rfft_ops_fft_length (self ):
5416
+
5417
+ x_val = make_xval ([3 , 9 ]).astype (np .float32 )
5418
+ def func1_length (x ):
5419
+ op_ = tf .signal .rfft (x , np .array ([8 ], dtype = np .int32 ))
5420
+ return tf .abs (op_ , name = _TFOUTPUT )
5421
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5422
+
5423
+ @check_tf_min_version ("1.14" )
5424
+ #@skip_tflite("FlexRFFT2D")
5425
+ @skip_tfjs ("TFJS executes rfft with poor accuracy" )
5426
+ @check_opset_min_version (10 , "Slice" )
5427
+ def test_rfft_ops_fft_length_many (self ):
5428
+ for i in range (4 , 7 ):
5429
+ for j in range (4 , 7 ):
5430
+ for m in range (0 , 3 ):
5431
+ with self .subTest (shape = (i , j ), fft_length = j - m ):
5432
+ x_val = make_xval ([i , j ]).astype (np .float32 )
5433
+ def func1_length (x ):
5434
+ op_ = tf .signal .rfft (x , np .array ([j - m ], dtype = np .int32 ))
5435
+ return tf .abs (op_ , name = _TFOUTPUT )
5436
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5437
+
5438
+ @check_tf_min_version ("1.14" )
5439
+ #@skip_tflite("FlexRFFT2D")
5440
+ @check_opset_min_version (10 , "Slice" )
5441
+ def test_rfft_ops_fft_length_many_bigger (self ):
5442
+ for i in range (4 , 7 ):
5443
+ for j in range (4 , 7 ):
5444
+ for m in range (0 , 3 ):
5445
+ with self .subTest (shape = (i , j ), fft_length = j + m ):
5446
+ x_val = make_xval ([i , j ]).astype (np .float32 ) / 10
5447
+ def func1_length (x ):
5448
+ op_ = tf .signal .rfft (x , np .array ([j + m ], dtype = np .int32 ))
5449
+ return tf .abs (op_ , name = _TFOUTPUT )
5450
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5451
+
5452
+ @check_tf_min_version ("1.14" )
5453
+ @skip_tflite ("Slight accuracy issues with some shapes" )
5454
+ @skip_tfjs ("TFJS executes rfft with poor accuracy" )
5455
+ @check_opset_min_version (10 , "Slice" )
5456
+ def test_rfft_ops_fft_length_many_larger (self ):
5457
+ for i in range (4 , 7 ):
5458
+ for j in range (4 , 7 ):
5459
+ for m in range (- 3 , 3 ):
5460
+ with self .subTest (shape = (3 , i , j ), fft_length = j + m ):
5461
+ x_val = make_xval ([3 , i , j ]).astype (np .float32 ) / 10
5462
+ def func1_length (x ):
5463
+ op_ = tf .signal .rfft (x , np .array ([j + m ], dtype = np .int32 ))
5464
+ return tf .abs (op_ , name = _TFOUTPUT )
5465
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val }, optimize = False )
5466
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5467
+
5468
+ @check_tf_min_version ("1.14" )
5469
+ #@skip_tflite("FlexRFFT2D")
5470
+ @check_opset_min_version (10 , "Slice" )
5471
+ def test_rfft2d_ops (self ):
5472
+
5473
+ x_val = make_xval ([3 , 4 ]).astype (np .float32 )
5474
+
5475
+ def func1 (x ):
5476
+ op_ = tf .signal .rfft2d (x )
5477
+ return tf .abs (op_ , name = _TFOUTPUT )
5478
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val }, optimize = False )
5479
+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val })
5480
+
5481
+ def func2 (x ):
5482
+ op_ = tf .signal .rfft2d (x )
5483
+ return tf .cos (op_ , name = _TFOUTPUT )
5484
+ with self .assertRaises (ValueError ):
5485
+ self ._run_test_case (func2 , [_OUTPUT ], {_INPUT : x_val })
5486
+
5487
+ def func3 (x ):
5488
+ op_ = tf .signal .rfft2d (x )
5489
+ return tf .identity (op_ , name = _TFOUTPUT )
5490
+ with self .assertRaises (ValueError ):
5491
+ self ._run_test_case (func3 , [_OUTPUT ], {_INPUT : x_val })
5492
+
5493
+ @check_tf_min_version ("1.14" )
5494
+ #@skip_tflite("FlexRFFT2D")
5495
+ @check_opset_min_version (10 , "Slice" )
5496
+ def test_rfft2d_ops_fft_length (self ):
5497
+
5498
+ x_val = make_xval ([3 , 4 ]).astype (np .float32 )
5499
+ def func1_length (x ):
5500
+ op_ = tf .signal .rfft2d (x , np .array ([3 , 3 ], dtype = np .int32 ))
5501
+ return tf .abs (op_ , name = _TFOUTPUT )
5502
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val }, optimize = False )
5503
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5504
+
5505
+ @check_tf_min_version ("1.14" )
5506
+ #@skip_tflite("FlexRFFT2D")
5507
+ @check_opset_min_version (10 , "Slice" )
5508
+ def test_rfft2d_ops_fft_length_many (self ):
5509
+ for i in range (7 , 4 , - 1 ):
5510
+ for j in range (7 , 4 , - 1 ):
5511
+ for m in range (0 , 3 ):
5512
+ for n in range (0 , 3 ):
5513
+ with self .subTest (shape = (i , j ), fft_length = (m , n )):
5514
+ x_val = make_xval ([i , j ]).astype (np .float32 ) / 100
5515
+ def func1_length (x ):
5516
+ op_ = tf .signal .rfft2d (x , np .array ([i - m , j - n ], dtype = np .int32 ))
5517
+ return tf .abs (op_ , name = _TFOUTPUT )
5518
+ self ._run_test_case (func1_length , [_OUTPUT ], {_INPUT : x_val })
5519
+
5520
+ @check_tf_min_version ("1.14" )
5521
+ @check_opset_min_version (10 , "Slice" )
5522
+ @unittest .skipIf (True , reason = "Not fully implemented for dynamic shape." )
5405
5523
def test_fft_ops (self ):
5406
5524
x_val = make_xval ([3 , 4 ]).astype (np .float32 )
5407
5525
def func1 (x ):
0 commit comments