@@ -652,6 +652,66 @@ def test_transpose_cast_cuda(self):
652652 self ._transpose_cast_cuda (TensorProto .FLOAT )
653653 self ._transpose_cast_cuda (TensorProto .FLOAT16 )
654654
def _replace_zero_cuda(self, itype):
    """Compare the CUDA ``ReplaceZero`` custom op with an Equal/Where graph.

    Two models taking a rank-3 input ``X`` and producing ``Y`` are built:

    * a reference model made of the standard ``Equal`` and ``Where``
      operators, which substitutes the constant ``1.67`` wherever ``X``
      equals zero;
    * a model holding a single ``ReplaceZero`` node from the
      ``ai.onnx.contrib`` domain with attribute ``by=1.67``.

    The reference model is evaluated with ``ReferenceEvaluator``; the custom
    op model is run through an inference session restricted to
    ``CUDAExecutionProvider``. Both outputs must agree within ``atol=1e-5``.

    :param itype: ONNX element type of ``X`` and ``Y``. ``TensorProto.FLOAT``
        selects ``np.float32``; any other value selects ``np.float16``.
    """
    # Computed once and shared by the initializers and the input tensor.
    dtype = np.float32 if itype == TensorProto.FLOAT else np.float16

    # Reference graph: Y = Where(X == zero, cst, X).
    model1 = helper.make_model(
        helper.make_graph(
            [
                helper.make_node("Equal", ["X", "zero"], ["cond"]),
                helper.make_node("Where", ["cond", "cst", "X"], ["Y"]),
            ],
            "nd",
            [helper.make_tensor_value_info("X", itype, [None, None, None])],
            [helper.make_tensor_value_info("Y", itype, [None, None, None])],
            [
                numpy_helper.from_array(np.array([0], dtype=dtype), name="zero"),
                numpy_helper.from_array(np.array([1.67], dtype=dtype), name="cst"),
            ],
        ),
        opset_imports=[helper.make_opsetid("", 18)],
        ir_version=9,
    )

    # Graph under test: the same replacement done by the single custom op.
    model2 = helper.make_model(
        helper.make_graph(
            [
                helper.make_node(
                    "ReplaceZero",
                    ["X"],
                    ["Y"],
                    by=1.67,
                    domain="ai.onnx.contrib",
                )
            ],
            "nd",
            [helper.make_tensor_value_info("X", itype, [None, None, None])],
            [helper.make_tensor_value_info("Y", itype, [None, None, None])],
        ),
        opset_imports=[
            helper.make_opsetid("", 18),
            helper.make_opsetid("ai.onnx.contrib", 1),
        ],
        ir_version=9,
    )

    # Values run from -4 to 13, so exactly one element of the input is zero.
    x = (np.arange(18) - 4).reshape((3, 2, 3)).astype(dtype)

    feeds1 = dict(X=x)
    ref = ReferenceEvaluator(model1)
    expected = ref.run(None, feeds1)[0]

    # The custom op library must be registered before the session is created.
    opts = _ort.SessionOptions()
    opts.register_custom_ops_library(_get_library_path())
    sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
    got = sess.run(None, feeds1)[0]
    assert_allclose(expected, got, atol=1e-5)
709+
@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_replace_zero_cuda(self):
    """Run the ReplaceZero CUDA check for float32, then for float16."""
    for element_type in (TensorProto.FLOAT, TensorProto.FLOAT16):
        self._replace_zero_cuda(element_type)
714+
655715
# Script entry point: run every test case of this module with verbose output.
if __name__ == "__main__":
    unittest.main(verbosity=2)
0 commit comments