@@ -822,3 +822,102 @@ def build_conv_act_pool_model(include_reshape_node=False):
822822 onnx .checker .check_model (model_inferred )
823823
824824 return model_inferred
825+
826+
827+ def build_conv_isinf_model (opset_version = 13 ):
828+ # Define your model inputs and outputs
829+ input_names = ["input_0" ]
830+ output_names = ["output_0" ]
831+ input_shapes = [(6 , 32 , 900 , 256 )]
832+ output_shapes = [(6 , 32 , 900 , 256 )]
833+
834+ inputs = [
835+ helper .make_tensor_value_info (input_name , onnx .TensorProto .FLOAT , input_shape )
836+ for input_name , input_shape in zip (input_names , input_shapes )
837+ ]
838+ outputs = [
839+ helper .make_tensor_value_info (output_name , onnx .TensorProto .FLOAT , output_shape )
840+ for output_name , output_shape in zip (output_names , output_shapes )
841+ ]
842+
843+ # Create the ONNX graph with the nodes
844+ nodes = [
845+ helper .make_node (
846+ op_type = "Conv" ,
847+ inputs = ["input_0" , "weights_1" ],
848+ outputs = ["conv1_conv/Conv2D:0" ],
849+ name = "conv1_conv/Conv2D" ,
850+ dilations = [1 , 1 ],
851+ group = 1 ,
852+ kernel_shape = [3 , 3 ],
853+ pads = [1 , 1 , 1 , 1 ],
854+ strides = [1 , 1 ],
855+ ),
856+ helper .make_node (
857+ op_type = "Cast" ,
858+ inputs = ["conv1_conv/Conv2D:0" ],
859+ outputs = ["cast1_cast/Cast:0" ],
860+ name = "cast1_cast/Cast" ,
861+ to = onnx .TensorProto .DOUBLE ,
862+ ),
863+ helper .make_node (
864+ op_type = "IsInf" ,
865+ inputs = ["cast1_cast/Cast:0" ],
866+ outputs = ["isinf1_isinf/IsInf:0" ],
867+ name = "isinf1_isinf/IsInf" ,
868+ ),
869+ helper .make_node (
870+ op_type = "Greater" ,
871+ inputs = ["conv1_conv/Conv2D:0" , "greater_const1" ],
872+ outputs = ["greater1_greater/Greater:0" ],
873+ name = "greater1_greater/Greater" ,
874+ ),
875+ helper .make_node (
876+ op_type = "And" ,
877+ inputs = ["isinf1_isinf/IsInf:0" , "greater1_greater/Greater:0" ],
878+ outputs = ["and1_and/And:0" ],
879+ name = "and1_and/And" ,
880+ ),
881+ helper .make_node (
882+ op_type = "Where" ,
883+ inputs = ["and1_and/And:0" , "conv1_conv/Conv2D:0" , "where_const1" ],
884+ outputs = ["output_0" ],
885+ name = "where1_where/Where" ,
886+ ),
887+ ]
888+
889+ # Create the ONNX initializers
890+ initializers = [
891+ helper .make_tensor (
892+ name = "weights_1" ,
893+ data_type = onnx .TensorProto .FLOAT ,
894+ dims = (32 , 32 , 3 , 3 ),
895+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 32 * 32 * 3 * 3 ),
896+ ),
897+ helper .make_tensor (
898+ name = "greater_const1" ,
899+ data_type = onnx .TensorProto .FLOAT ,
900+ dims = (1 ,),
901+ vals = [0 ],
902+ ),
903+ helper .make_tensor (
904+ name = "where_const1" ,
905+ data_type = onnx .TensorProto .FLOAT ,
906+ dims = (1 ,),
907+ vals = [10000 ],
908+ ),
909+ ]
910+
911+ # Create the ONNX graph with the nodes and initializers
912+ graph = helper .make_graph (nodes , "conv_isinf" , inputs , outputs , initializer = initializers )
913+
914+ # Create the ONNX model
915+ model = helper .make_model (graph )
916+ model .opset_import [0 ].version = opset_version
917+ model .ir_version = 10
918+
919+ # Check the ONNX model
920+ model_inferred = onnx .shape_inference .infer_shapes (model )
921+ onnx .checker .check_model (model_inferred )
922+
923+ return model_inferred
0 commit comments