@@ -64,36 +64,37 @@ TEST(TensorRTEngineOp, manual) {
64
64
65
65
LOG (INFO) << " create block desc" ;
66
66
framework::BlockDesc block_desc (&program, block_);
67
- LOG (INFO) << " create mul op" ;
68
- auto * mul = block_desc.AppendOp ();
69
- mul ->SetType (" mul" );
70
- mul ->SetInput (" X" , std::vector<std::string>({" x" })); // 2 x 4
71
- mul ->SetInput (" Y" , std::vector<std::string>({" y" })); // 4 x 6
72
- mul ->SetOutput (" Out" , std::vector<std::string>({" z" })); // 2 x 6
67
+ LOG (INFO) << " create fc op" ;
68
+ auto * fc0 = block_desc.AppendOp ();
69
+ fc0 ->SetType (" mul" );
70
+ fc0 ->SetInput (" X" , std::vector<std::string>({" x" })); // 4 x 1 x 1
71
+ fc0 ->SetInput (" Y" , std::vector<std::string>({" y" })); // 4 x 6
72
+ fc0 ->SetOutput (" Out" , std::vector<std::string>({" z" })); // 6 x 1 x 1
73
73
74
74
LOG (INFO) << " create fc op" ;
75
- auto * fc = block_desc.AppendOp ();
76
- fc ->SetType (" mul" );
77
- fc ->SetInput (" X" , std::vector<std::string>({" z" }));
78
- fc ->SetInput (" Y" , std::vector<std::string>({" y0" })); // 6 x 8
79
- fc ->SetOutput (" Out" , std::vector<std::string>({" z0" })); // 2 x 8
75
+ auto * fc1 = block_desc.AppendOp ();
76
+ fc1 ->SetType (" mul" );
77
+ fc1 ->SetInput (" X" , std::vector<std::string>({" z" }));
78
+ fc1 ->SetInput (" Y" , std::vector<std::string>({" y0" })); // 6 x 8
79
+ fc1 ->SetOutput (" Out" , std::vector<std::string>({" z0" })); // 8 x 1 x 1
80
80
81
81
// Set inputs' variable shape in BlockDesc
82
- AddTensorToBlockDesc (block_, " x" , std::vector<int64_t >({2 , 4 }));
82
+ // the batch size is 2, so the dims of 'x' are {2, 4, 1, 1}
83
+ AddTensorToBlockDesc (block_, " x" , std::vector<int64_t >({2 , 4 , 1 , 1 }));
83
84
AddTensorToBlockDesc (block_, " y" , std::vector<int64_t >({4 , 6 }));
84
85
AddTensorToBlockDesc (block_, " y0" , std::vector<int64_t >({6 , 8 }));
85
86
AddTensorToBlockDesc (block_, " z" , std::vector<int64_t >({2 , 6 }));
86
87
87
88
// It is weird, need to copy manually.
88
- *block_->add_ops () = *mul ->Proto ();
89
- *block_->add_ops () = *fc ->Proto ();
89
+ *block_->add_ops () = *fc0 ->Proto ();
90
+ *block_->add_ops () = *fc1 ->Proto ();
90
91
91
92
ASSERT_EQ (block_->ops_size (), 2 );
92
93
93
94
LOG (INFO) << " create tensorrt desc" ;
94
95
framework::OpDesc engine_op_desc (nullptr );
95
96
engine_op_desc.SetType (" tensorrt_engine" );
96
- engine_op_desc.SetInput (" Xs" , std::vector<std::string>({" x" , " y " , " y0 " }));
97
+ engine_op_desc.SetInput (" Xs" , std::vector<std::string>({" x" }));
97
98
engine_op_desc.SetOutput (" Ys" , std::vector<std::string>({" z0" }));
98
99
SetAttr<std::string>(engine_op_desc.Proto (), " subgraph" ,
99
100
block_->SerializeAsString ());
@@ -208,4 +209,3 @@ TEST(TensorRTEngineOp, fc) { Execute(40, 28, 28); }
208
209
} // namespace paddle
209
210
210
211
// NOTE(review): presumably makes the TensorRT converter for the "mul" op type
// available to this test binary; both ops built in the test above are created
// with SetType("mul"). Confirm against the macro's definition.
USE_TRT_CONVERTER (mul)
211
- USE_TRT_CONVERTER(fc)
0 commit comments