@@ -372,3 +372,60 @@ def build_conv_concat_model():
372372    onnx .checker .check_model (model_inferred )
373373
374374    return  model_inferred 
375+ 
376+ 
377+ def  build_matmul_relu_model_ir_12 ():
378+     # Define your model inputs and outputs 
379+     input_names  =  ["input_0" ]
380+     output_names  =  ["output_0" ]
381+     input_shapes  =  [(1 , 1024 , 1024 )]
382+     output_shapes  =  [(1 , 1024 , 16 )]
383+ 
384+     inputs  =  [
385+         helper .make_tensor_value_info (input_name , onnx .TensorProto .FLOAT , input_shape )
386+         for  input_name , input_shape  in  zip (input_names , input_shapes )
387+     ]
388+     outputs  =  [
389+         helper .make_tensor_value_info (output_name , onnx .TensorProto .FLOAT , output_shape )
390+         for  output_name , output_shape  in  zip (output_names , output_shapes )
391+     ]
392+ 
393+     # Create the ONNX graph with the nodes 
394+     nodes  =  [
395+         helper .make_node (
396+             op_type = "MatMul" ,
397+             inputs = ["input_0" , "weights_1" ],
398+             outputs = ["matmul1_matmul/MatMul:0" ],
399+             name = "matmul1_matmul/MatMul" ,
400+         ),
401+         helper .make_node (
402+             op_type = "Relu" ,
403+             inputs = ["matmul1_matmul/MatMul:0" ],
404+             outputs = ["output_0" ],
405+             name = "relu1_relu/Relu" ,
406+         ),
407+     ]
408+ 
409+     # Create the ONNX initializers 
410+     initializers  =  [
411+         helper .make_tensor (
412+             name = "weights_1" ,
413+             data_type = onnx .TensorProto .FLOAT ,
414+             dims = (1024 , 16 ),
415+             vals = np .random .uniform (low = 0.5 , high = 1.0 , size = 1024  *  16 ),
416+         ),
417+     ]
418+ 
419+     # Create the ONNX graph with the nodes and initializers 
420+     graph  =  helper .make_graph (nodes , "r1a" , inputs , outputs , initializer = initializers )
421+ 
422+     # Create the ONNX model 
423+     model  =  helper .make_model (graph )
424+     model .opset_import [0 ].version  =  13 
425+     model .ir_version  =  12 
426+ 
427+     # Check the ONNX model 
428+     model_inferred  =  onnx .shape_inference .infer_shapes (model )
429+     onnx .checker .check_model (model_inferred )
430+ 
431+     return  model_inferred 
0 commit comments