@@ -1519,17 +1519,19 @@ def test_erf(self):
1519
1519
_ = tf .identity (x_ , name = _TFOUTPUT )
1520
1520
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, rtol = 0.01 )
1521
1521
1522
- @check_opset_min_version (8 , "Scan" )
1523
- @skip_opset (9 , "ReverseSequence" )
1524
- def test_reverse_sequence_batch_major (self ):
1522
+ def _test_reverse_sequence_batch_major (self , extra_opset = None ):
1523
+ process_args = {}
1524
+ if extra_opset is not None :
1525
+ process_args ["extra_opset" ] = [extra_opset ]
1526
+
1525
1527
x_val = np .array ([[[1 , 2 , 3 ], [4 , 5 , 6 ], [0 , 0 , 0 ]],
1526
1528
[[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]],
1527
1529
[[1 , 2 , 3 ], [0 , 0 , 0 ], [0 , 0 , 0 ]]],
1528
1530
dtype = np .float32 )
1529
1531
x = tf .placeholder (tf .float32 , [None , 3 , 3 ], name = _TFINPUT )
1530
1532
x_ = tf .reverse_sequence (x , seq_axis = 1 , batch_axis = 0 , seq_lengths = [2 , 3 , 1 ])
1531
1533
_ = tf .identity (x_ , name = _TFOUTPUT )
1532
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1534
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, process_args = process_args )
1533
1535
tf .reset_default_graph ()
1534
1536
1535
1537
x_val = np .array ([[1 , 2 , 3 ], [1 , 2 , 3 ], [1 , 2 , 3 ],
@@ -1540,19 +1542,21 @@ def test_reverse_sequence_batch_major(self):
1540
1542
x = tf .placeholder (tf .float32 , [None , 3 ], name = _TFINPUT )
1541
1543
x_ = tf .reverse_sequence (x , seq_axis = 1 , batch_axis = 0 , seq_lengths = [3 ] * 9 )
1542
1544
_ = tf .identity (x_ , name = _TFOUTPUT )
1543
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1545
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, process_args = process_args )
1544
1546
tf .reset_default_graph ()
1545
1547
1546
1548
x_val_shape = [5 , 5 , 7 , 8 , 9 ]
1547
1549
x_val = np .random .randint (0 , 100 , x_val_shape ).astype (np .float32 )
1548
1550
x = tf .placeholder (tf .float32 , [None , 5 , 7 , 8 , 9 ], name = _TFINPUT )
1549
1551
x_ = tf .reverse_sequence (x , seq_axis = 1 , batch_axis = 0 , seq_lengths = [5 , 5 , 5 , 5 , 5 ])
1550
1552
_ = tf .identity (x_ , name = _TFOUTPUT )
1551
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1553
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, process_args = process_args )
1554
+
1555
+ def _test_reverse_sequence_time_major (self , extra_opset = None ):
1556
+ process_args = {}
1557
+ if extra_opset is not None :
1558
+ process_args ["extra_opset" ] = [extra_opset ]
1552
1559
1553
- @check_opset_min_version (8 , "Scan" )
1554
- @skip_opset (9 , "ReverseSequence" )
1555
- def test_reverse_sequence_time_major (self ):
1556
1560
x_val = np .array ([[[1 , 2 , 3 ], [1 , 2 , 3 ], [1 , 2 , 3 ]],
1557
1561
[[4 , 5 , 6 ], [4 , 5 , 6 ], [0 , 0 , 0 ]],
1558
1562
[[0 , 0 , 0 ], [7 , 8 , 9 ], [0 , 0 , 0 ]]
@@ -1561,7 +1565,7 @@ def test_reverse_sequence_time_major(self):
1561
1565
x = tf .placeholder (tf .float32 , [3 , None , 3 ], name = _TFINPUT )
1562
1566
x_ = tf .reverse_sequence (x , seq_axis = 0 , batch_axis = 1 , seq_lengths = [2 , 3 , 1 ])
1563
1567
_ = tf .identity (x_ , name = _TFOUTPUT )
1564
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1568
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, process_args = process_args )
1565
1569
tf .reset_default_graph ()
1566
1570
1567
1571
x_val = np .array ([[1 , 2 , 3 ], [1 , 2 , 3 ], [1 , 2 , 3 ],
@@ -1572,15 +1576,35 @@ def test_reverse_sequence_time_major(self):
1572
1576
x = tf .placeholder (tf .float32 , [9 , None ], name = _TFINPUT )
1573
1577
x_ = tf .reverse_sequence (x , seq_axis = 0 , batch_axis = 1 , seq_lengths = [9 , 9 , 9 ])
1574
1578
_ = tf .identity (x_ , name = _TFOUTPUT )
1575
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1579
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, process_args = process_args )
1576
1580
tf .reset_default_graph ()
1577
1581
1578
1582
x_val_shape = [5 , 5 , 7 , 8 , 9 ]
1579
1583
x_val = np .random .randint (0 , 100 , x_val_shape ).astype (np .float32 )
1580
1584
x = tf .placeholder (tf .float32 , [5 , None , 7 , 8 , 9 ], name = _TFINPUT )
1581
1585
x_ = tf .reverse_sequence (x , seq_axis = 0 , batch_axis = 1 , seq_lengths = [5 , 5 , 5 , 5 , 5 ])
1582
1586
_ = tf .identity (x_ , name = _TFOUTPUT )
1583
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1587
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, process_args = process_args )
1588
+
1589
+ @check_opset_min_version (8 , "Scan" )
1590
+ @skip_opset (9 , "ReverseSequence" )
1591
+ def test_reverse_sequence_batch_major (self ):
1592
+ self ._test_reverse_sequence_batch_major ()
1593
+
1594
+ @check_opset_min_version (8 , "Scan" )
1595
+ @skip_opset (9 , "ReverseSequence" )
1596
+ def test_reverse_sequence_time_major (self ):
1597
+ self ._test_reverse_sequence_time_major ()
1598
+
1599
+ @test_ms_domain ()
1600
+ @unittest .skipIf (True , "not support in pypi onnxruntime" )
1601
+ def test_ms_reverse_sequence_batch_major (self , extra_opset ):
1602
+ self ._test_reverse_sequence_batch_major (extra_opset )
1603
+
1604
+ @test_ms_domain ()
1605
+ @unittest .skipIf (True , "not support in pypi onnxruntime" )
1606
+ def test_ms_reverse_sequence_time_major (self , extra_opset ):
1607
+ self ._test_reverse_sequence_time_major (extra_opset )
1584
1608
1585
1609
@check_opset_min_version (8 , "where" )
1586
1610
def test_where (self ):
0 commit comments