@@ -1539,17 +1539,19 @@ def test_erf(self):
1539
1539
_ = tf .identity (x_ , name = _TFOUTPUT )
1540
1540
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, rtol = 0.01 )
1541
1541
1542
- @check_opset_min_version (8 , "Scan" )
1543
- @skip_opset (9 , "ReverseSequence" )
1544
- def test_reverse_sequence_batch_major (self ):
1542
+ def _test_reverse_sequence_batch_major (self , extra_opset = None ):
1543
+ process_args = {}
1544
+ if extra_opset is not None :
1545
+ process_args ["extra_opset" ] = [extra_opset ]
1546
+
1545
1547
x_val = np .array ([[[1 , 2 , 3 ], [4 , 5 , 6 ], [0 , 0 , 0 ]],
1546
1548
[[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]],
1547
1549
[[1 , 2 , 3 ], [0 , 0 , 0 ], [0 , 0 , 0 ]]],
1548
1550
dtype = np .float32 )
1549
1551
x = tf .placeholder (tf .float32 , [None , 3 , 3 ], name = _TFINPUT )
1550
1552
x_ = tf .reverse_sequence (x , seq_axis = 1 , batch_axis = 0 , seq_lengths = [2 , 3 , 1 ])
1551
1553
_ = tf .identity (x_ , name = _TFOUTPUT )
1552
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1554
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, process_args = process_args )
1553
1555
tf .reset_default_graph ()
1554
1556
1555
1557
x_val = np .array ([[1 , 2 , 3 ], [1 , 2 , 3 ], [1 , 2 , 3 ],
@@ -1560,19 +1562,21 @@ def test_reverse_sequence_batch_major(self):
1560
1562
x = tf .placeholder (tf .float32 , [None , 3 ], name = _TFINPUT )
1561
1563
x_ = tf .reverse_sequence (x , seq_axis = 1 , batch_axis = 0 , seq_lengths = [3 ] * 9 )
1562
1564
_ = tf .identity (x_ , name = _TFOUTPUT )
1563
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1565
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, process_args = process_args )
1564
1566
tf .reset_default_graph ()
1565
1567
1566
1568
x_val_shape = [5 , 5 , 7 , 8 , 9 ]
1567
1569
x_val = np .random .randint (0 , 100 , x_val_shape ).astype (np .float32 )
1568
1570
x = tf .placeholder (tf .float32 , [None , 5 , 7 , 8 , 9 ], name = _TFINPUT )
1569
1571
x_ = tf .reverse_sequence (x , seq_axis = 1 , batch_axis = 0 , seq_lengths = [5 , 5 , 5 , 5 , 5 ])
1570
1572
_ = tf .identity (x_ , name = _TFOUTPUT )
1571
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1573
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, process_args = process_args )
1574
+
1575
+ def _test_reverse_sequence_time_major (self , extra_opset = None ):
1576
+ process_args = {}
1577
+ if extra_opset is not None :
1578
+ process_args ["extra_opset" ] = [extra_opset ]
1572
1579
1573
- @check_opset_min_version (8 , "Scan" )
1574
- @skip_opset (9 , "ReverseSequence" )
1575
- def test_reverse_sequence_time_major (self ):
1576
1580
x_val = np .array ([[[1 , 2 , 3 ], [1 , 2 , 3 ], [1 , 2 , 3 ]],
1577
1581
[[4 , 5 , 6 ], [4 , 5 , 6 ], [0 , 0 , 0 ]],
1578
1582
[[0 , 0 , 0 ], [7 , 8 , 9 ], [0 , 0 , 0 ]]
@@ -1581,7 +1585,7 @@ def test_reverse_sequence_time_major(self):
1581
1585
x = tf .placeholder (tf .float32 , [3 , None , 3 ], name = _TFINPUT )
1582
1586
x_ = tf .reverse_sequence (x , seq_axis = 0 , batch_axis = 1 , seq_lengths = [2 , 3 , 1 ])
1583
1587
_ = tf .identity (x_ , name = _TFOUTPUT )
1584
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1588
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, process_args = process_args )
1585
1589
tf .reset_default_graph ()
1586
1590
1587
1591
x_val = np .array ([[1 , 2 , 3 ], [1 , 2 , 3 ], [1 , 2 , 3 ],
@@ -1592,15 +1596,36 @@ def test_reverse_sequence_time_major(self):
1592
1596
x = tf .placeholder (tf .float32 , [9 , None ], name = _TFINPUT )
1593
1597
x_ = tf .reverse_sequence (x , seq_axis = 0 , batch_axis = 1 , seq_lengths = [9 , 9 , 9 ])
1594
1598
_ = tf .identity (x_ , name = _TFOUTPUT )
1595
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1599
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, process_args = process_args )
1596
1600
tf .reset_default_graph ()
1597
1601
1598
1602
x_val_shape = [5 , 5 , 7 , 8 , 9 ]
1599
1603
x_val = np .random .randint (0 , 100 , x_val_shape ).astype (np .float32 )
1600
1604
x = tf .placeholder (tf .float32 , [5 , None , 7 , 8 , 9 ], name = _TFINPUT )
1601
1605
x_ = tf .reverse_sequence (x , seq_axis = 0 , batch_axis = 1 , seq_lengths = [5 , 5 , 5 , 5 , 5 ])
1602
1606
_ = tf .identity (x_ , name = _TFOUTPUT )
1603
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1607
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val }, process_args = process_args )
1608
+
1609
+ @check_opset_min_version (8 , "Scan" )
1610
+ @skip_opset (9 , "ReverseSequence" )
1611
+ def test_reverse_sequence_batch_major (self ):
1612
+ self ._test_reverse_sequence_batch_major ()
1613
+
1614
+ @check_opset_min_version (8 , "Scan" )
1615
+ @skip_opset (9 , "ReverseSequence" )
1616
+ def test_reverse_sequence_time_major (self ):
1617
+ self ._test_reverse_sequence_time_major ()
1618
+
1619
+ # only support onnxruntime with version larger than 0.4.0
1620
+ @test_ms_domain ()
1621
+ @check_onnxruntime_min_version ("0.4.0" )
1622
+ def test_ms_reverse_sequence_batch_major (self , extra_opset ):
1623
+ self ._test_reverse_sequence_batch_major (extra_opset )
1624
+
1625
+ @test_ms_domain ()
1626
+ @check_onnxruntime_min_version ("0.4.0" )
1627
+ def test_ms_reverse_sequence_time_major (self , extra_opset ):
1628
+ self ._test_reverse_sequence_time_major (extra_opset )
1604
1629
1605
1630
@check_opset_min_version (8 , "where" )
1606
1631
def test_where (self ):
0 commit comments