@@ -372,3 +372,185 @@ def build_conv_concat_model():
372372 onnx .checker .check_model (model_inferred )
373373
374374 return model_inferred
375+
376+
377+ def build_convtranspose_conv_residual_model ():
378+ # Define your model inputs and outputs
379+ input_names = ["input_0" ]
380+ output_names = ["output_0" ]
381+ input_shapes = [(2 , 39 , 96 , 192 )]
382+ output_shapes = [(2 , 32 , 192 , 384 )]
383+
384+ inputs = [
385+ helper .make_tensor_value_info (input_name , onnx .TensorProto .FLOAT , input_shape )
386+ for input_name , input_shape in zip (input_names , input_shapes )
387+ ]
388+ outputs = [
389+ helper .make_tensor_value_info (output_name , onnx .TensorProto .FLOAT , output_shape )
390+ for output_name , output_shape in zip (output_names , output_shapes )
391+ ]
392+
393+ # Create the ONNX graph with the nodes
394+ nodes = [
395+ helper .make_node (
396+ op_type = "ConvTranspose" ,
397+ inputs = ["input_0" , "weights_1" , "bias_1" ],
398+ outputs = ["convtranspose1_convtranspose/ConvTranspose:0" ],
399+ name = "convtranspose1_convtranspose/ConvTranspose" ,
400+ dilations = [1 , 1 ],
401+ group = 1 ,
402+ kernel_shape = [2 , 2 ],
403+ pads = [0 , 0 , 0 , 0 ],
404+ strides = [2 , 2 ],
405+ ),
406+ helper .make_node (
407+ op_type = "Relu" ,
408+ inputs = ["convtranspose1_convtranspose/ConvTranspose:0" ],
409+ outputs = ["relu1_relu/Relu:0" ],
410+ name = "relu1_relu/Relu" ,
411+ ),
412+ helper .make_node (
413+ op_type = "Conv" ,
414+ inputs = ["relu1_relu/Relu:0" , "weights_2" ],
415+ outputs = ["conv2_conv/Conv2D:0" ],
416+ name = "conv2_conv/Conv2D" ,
417+ dilations = [1 , 1 ],
418+ group = 1 ,
419+ kernel_shape = [3 , 3 ],
420+ pads = [1 , 1 , 1 , 1 ],
421+ strides = [1 , 1 ],
422+ ),
423+ helper .make_node (
424+ op_type = "BatchNormalization" ,
425+ inputs = ["conv2_conv/Conv2D:0" , "bn1_scale" , "bn1_bias" , "bn1_mean" , "bn1_var" ],
426+ outputs = ["bn1_batchnorm/BatchNormalization:0" ],
427+ name = "bn1_batchnorm/BatchNormalization" ,
428+ ),
429+ helper .make_node (
430+ op_type = "Relu" ,
431+ inputs = ["bn1_batchnorm/BatchNormalization:0" ],
432+ outputs = ["relu2_relu/Relu:0" ],
433+ name = "relu2_relu/Relu" ,
434+ ),
435+ helper .make_node (
436+ op_type = "Conv" ,
437+ inputs = ["relu2_relu/Relu:0" , "weights_3" ],
438+ outputs = ["conv3_conv/Conv2D:0" ],
439+ name = "conv3_conv/Conv2D" ,
440+ dilations = [1 , 1 ],
441+ group = 1 ,
442+ kernel_shape = [3 , 3 ],
443+ pads = [1 , 1 , 1 , 1 ],
444+ strides = [1 , 1 ],
445+ ),
446+ helper .make_node (
447+ op_type = "BatchNormalization" ,
448+ inputs = ["conv3_conv/Conv2D:0" , "bn2_scale" , "bn2_bias" , "bn2_mean" , "bn2_var" ],
449+ outputs = ["bn2_batchnorm/BatchNormalization:0" ],
450+ name = "bn2_batchnorm/BatchNormalization" ,
451+ ),
452+ helper .make_node (
453+ op_type = "Add" ,
454+ inputs = ["relu1_relu/Relu:0" , "bn2_batchnorm/BatchNormalization:0" ],
455+ outputs = ["add1_add/Add:0" ],
456+ name = "add1_add/Add" ,
457+ ),
458+ helper .make_node (
459+ op_type = "Relu" ,
460+ inputs = ["add1_add/Add:0" ],
461+ outputs = ["output_0" ],
462+ name = "relu3_relu/Relu" ,
463+ ),
464+ ]
465+
466+ # Create the ONNX initializers
467+ initializers = [
468+ helper .make_tensor (
469+ name = "weights_1" ,
470+ data_type = onnx .TensorProto .FLOAT ,
471+ dims = (39 , 32 , 2 , 2 ),
472+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 39 * 32 * 2 * 2 ),
473+ ),
474+ helper .make_tensor (
475+ name = "bias_1" ,
476+ data_type = onnx .TensorProto .FLOAT ,
477+ dims = (32 ,),
478+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 32 ),
479+ ),
480+ helper .make_tensor (
481+ name = "weights_2" ,
482+ data_type = onnx .TensorProto .FLOAT ,
483+ dims = (32 , 32 , 3 , 3 ),
484+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 32 * 32 * 3 * 3 ),
485+ ),
486+ helper .make_tensor (
487+ name = "bn1_scale" ,
488+ data_type = onnx .TensorProto .FLOAT ,
489+ dims = (32 ,),
490+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 32 ),
491+ ),
492+ helper .make_tensor (
493+ name = "bn1_bias" ,
494+ data_type = onnx .TensorProto .FLOAT ,
495+ dims = (32 ,),
496+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 32 ),
497+ ),
498+ helper .make_tensor (
499+ name = "bn1_mean" ,
500+ data_type = onnx .TensorProto .FLOAT ,
501+ dims = (32 ,),
502+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 32 ),
503+ ),
504+ helper .make_tensor (
505+ name = "bn1_var" ,
506+ data_type = onnx .TensorProto .FLOAT ,
507+ dims = (32 ,),
508+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 32 ),
509+ ),
510+ helper .make_tensor (
511+ name = "weights_3" ,
512+ data_type = onnx .TensorProto .FLOAT ,
513+ dims = (32 , 32 , 3 , 3 ),
514+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 32 * 32 * 3 * 3 ),
515+ ),
516+ helper .make_tensor (
517+ name = "bn2_scale" ,
518+ data_type = onnx .TensorProto .FLOAT ,
519+ dims = (32 ,),
520+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 32 ),
521+ ),
522+ helper .make_tensor (
523+ name = "bn2_bias" ,
524+ data_type = onnx .TensorProto .FLOAT ,
525+ dims = (32 ,),
526+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 32 ),
527+ ),
528+ helper .make_tensor (
529+ name = "bn2_mean" ,
530+ data_type = onnx .TensorProto .FLOAT ,
531+ dims = (32 ,),
532+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 32 ),
533+ ),
534+ helper .make_tensor (
535+ name = "bn2_var" ,
536+ data_type = onnx .TensorProto .FLOAT ,
537+ dims = (32 ,),
538+ vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 32 ),
539+ ),
540+ ]
541+
542+ # Create the ONNX graph with the nodes and initializers
543+ graph = helper .make_graph (
544+ nodes , "convtranspose_conv_residual" , inputs , outputs , initializer = initializers
545+ )
546+
547+ # Create the ONNX model
548+ model = helper .make_model (graph )
549+ model .opset_import [0 ].version = 13
550+ model .ir_version = 10
551+
552+ # Check the ONNX model
553+ model_inferred = onnx .shape_inference .infer_shapes (model )
554+ onnx .checker .check_model (model_inferred )
555+
556+ return model_inferred
0 commit comments