@@ -1548,6 +1548,28 @@ def test_reverse_sequence_time_major(self):
1548
1548
1549
1549
@check_opset_min_version (8 , "where" )
1550
1550
def test_where (self ):
1551
+ x_val = np .array ([1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ], dtype = np .float32 )
1552
+ true_result = np .array ([111 , 222 , 333 , 444 , 555 , 666 , 777 , 888 , 999 , 1000 ],
1553
+ dtype = np .float32 )
1554
+ false_result = np .array ([- 111 , - 222 , - 333 , - 444 , - 555 , - 666 , - 777 , - 888 , - 999 , - 1000 ],
1555
+ dtype = np .float32 )
1556
+ x = tf .placeholder (tf .float32 , x_val .shape , name = _TFINPUT )
1557
+ picks = tf .where (x > - 1 , true_result , false_result )
1558
+ _ = tf .identity (picks , name = _TFOUTPUT )
1559
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1560
+
1561
+ tf .reset_default_graph ()
1562
+ x_val = np .array (1 , dtype = np .float32 )
1563
+ true_result = np .array (100 , dtype = np .float32 )
1564
+ false_result = np .array (- 111 , dtype = np .float32 )
1565
+ x = tf .placeholder (tf .float32 , x_val .shape , name = _TFINPUT )
1566
+ picks = tf .where (x > - 1 , true_result , false_result )
1567
+ _ = tf .identity (picks , name = _TFOUTPUT )
1568
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1569
+
1570
+ @check_opset_min_version (8 , "where" )
1571
+ @check_target ("rs6" , "onnxruntime Where type limitation" )
1572
+ def test_where_int32 (self ):
1551
1573
x_val = np .array ([1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ], dtype = np .int32 )
1552
1574
true_result = np .array ([111 , 222 , 333 , 444 , 555 , 666 , 777 , 888 , 999 , 1000 ],
1553
1575
dtype = np .int32 )
@@ -1560,59 +1582,59 @@ def test_where(self):
1560
1582
1561
1583
@check_opset_min_version (8 , "where" )
1562
1584
def test_where_with_two_rank_input (self ):
1563
- x_val = np .array ([1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ], dtype = np .int32 )
1585
+ x_val = np .array ([1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ], dtype = np .float32 )
1564
1586
true_result = np .array ([[111 , 111 ], [222 , 222 ], [333 , 333 ], [444 , 444 ], [555 , 555 ],
1565
1587
[666 , 666 ], [777 , 777 ], [888 , 888 ], [999 , 999 ], [1000 , 1000 ]],
1566
- dtype = np .int32 )
1588
+ dtype = np .float32 )
1567
1589
false_result = np .array ([[- 111 , - 111 ], [- 222 , - 222 ], [- 333 , - 333 ], [- 444 , - 444 ],
1568
1590
[- 555 , - 555 ], [- 666 , - 666 ], [- 777 , - 777 ], [- 888 , - 888 ],
1569
1591
[- 999 , - 999 ], [- 1000 , - 1000 ]],
1570
- dtype = np .int32 )
1571
- x = tf .placeholder (tf .int32 , [None ], name = _TFINPUT )
1592
+ dtype = np .float32 )
1593
+ x = tf .placeholder (tf .float32 , [None ], name = _TFINPUT )
1572
1594
picks = tf .where (tf .greater_equal (x , 0 ), true_result , false_result )
1573
1595
_ = tf .identity (picks , name = _TFOUTPUT )
1574
1596
1575
1597
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1576
1598
1577
1599
@check_opset_min_version (8 , "where" )
1578
1600
def test_where_with_two_rank_condition (self ):
1579
- x_val = np .array ([[1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ]], dtype = np .int32 )
1601
+ x_val = np .array ([[1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ]], dtype = np .float32 )
1580
1602
true_result = np .array ([[111 , 222 , 333 , 444 , 555 , 666 , 777 , 888 , 999 , 1000 ]],
1581
- dtype = np .int32 )
1603
+ dtype = np .float32 )
1582
1604
false_result = np .array ([[- 111 , - 222 , - 333 , - 444 , - 555 , - 666 , - 777 , - 888 , - 999 , - 1000 ]],
1583
- dtype = np .int32 )
1584
- x = tf .placeholder (tf .int32 , [1 , 10 ], name = _TFINPUT )
1605
+ dtype = np .float32 )
1606
+ x = tf .placeholder (tf .float32 , [1 , 10 ], name = _TFINPUT )
1585
1607
picks = tf .where (tf .greater_equal (x , 0 ), true_result , false_result )
1586
1608
_ = tf .identity (picks , name = _TFOUTPUT )
1587
1609
1588
1610
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1589
1611
1590
1612
@check_opset_min_version (8 , "where" )
1591
1613
def test_where_with_three_rank_condition (self ):
1592
- x_val = np .array ([[[1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ]]], dtype = np .int32 )
1614
+ x_val = np .array ([[[1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ]]], dtype = np .float32 )
1593
1615
true_result = np .array ([[[111 , 222 , 333 , 444 , 555 , 666 , 777 , 888 , 999 , 1000 ]]],
1594
- dtype = np .int32 )
1616
+ dtype = np .float32 )
1595
1617
false_result = np .array ([[[- 111 , - 222 , - 333 , - 444 , - 555 , - 666 , - 777 , - 888 , - 999 , - 1000 ]]],
1596
- dtype = np .int32 )
1597
- x = tf .placeholder (tf .int32 , [1 , 1 , 10 ], name = _TFINPUT )
1618
+ dtype = np .float32 )
1619
+ x = tf .placeholder (tf .float32 , [1 , 1 , 10 ], name = _TFINPUT )
1598
1620
picks = tf .where (tf .greater_equal (x , 0 ), true_result , false_result )
1599
1621
_ = tf .identity (picks , name = _TFOUTPUT )
1600
1622
1601
1623
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1602
1624
1603
1625
@check_opset_min_version (8 , "where" )
1604
1626
def test_where_scalar (self ):
1605
- x_val = np .array (6 , dtype = np .int32 )
1627
+ x_val = np .array (6 , dtype = np .float32 )
1606
1628
true_result = np .array ([111 , 222 , 333 , 444 , 555 , 666 , 777 , 888 , 999 , 1000 ],
1607
- dtype = np .int32 )
1629
+ dtype = np .float32 )
1608
1630
false_result = np .array ([- 111 , - 222 , - 333 , - 444 , - 555 , - 666 , - 777 , - 888 , - 999 , - 1000 ],
1609
- dtype = np .int32 )
1610
- x = tf .placeholder (tf .int32 , [], name = _TFINPUT )
1631
+ dtype = np .float32 )
1632
+ x = tf .placeholder (tf .float32 , [], name = _TFINPUT )
1611
1633
picks = tf .where (tf .greater_equal (x , 0 ), true_result , false_result )
1612
1634
_ = tf .identity (picks , name = _TFOUTPUT )
1613
1635
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1614
1636
1615
- @check_opset_min_version (9 , "where " )
1637
+ @check_opset_min_version (9 , "NonZero " )
1616
1638
@check_target ("rs6" , "onnxruntime Transpose type limitation" )
1617
1639
def test_where_with_cond_only (self ):
1618
1640
for np_type , tf_type in [(np .int32 , tf .int32 ), (np .float32 , tf .float32 )]:
0 commit comments