@@ -31,11 +31,17 @@ void AddEngineToGraph(
31
31
torch::jit::script::Module mod,
32
32
std::shared_ptr<torch::jit::Graph>& g,
33
33
const std::string& serialized_engine,
34
- runtime::CudaDevice& device_info,
34
+ runtime::RTDevice& device_info,
35
+ const std::vector<std::string>& input_binding_names,
36
+ const std::vector<std::string>& output_binding_names,
35
37
std::string engine_id = " " ,
36
38
bool fallback = false ) {
37
39
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(
38
- mod._ivalue ()->name () + " _engine_" + engine_id, serialized_engine, device_info);
40
+ mod._ivalue ()->name () + " _engine_" + engine_id,
41
+ serialized_engine,
42
+ device_info,
43
+ input_binding_names,
44
+ output_binding_names);
39
45
// Get required metadata about the engine out
40
46
auto num_io = engine_ptr->num_io ;
41
47
auto name = engine_ptr->name ;
@@ -137,10 +143,13 @@ partitioning::GraphAndMapping BuildHybridGraph(
137
143
auto partitioning_info = cfg.partitioning_info ;
138
144
139
145
auto partitioning_ctx = partitioning::PartitioningCtx (block, partitioning_info);
140
- auto collection_input_ivalues_map =
141
- partitioning::generateRandomInputs (partitioning_info.collection_input_spec_map , first_use_types);
146
+ partitioning_ctx.input_types_map = first_use_types;
142
147
143
- partitioning::partition (&partitioning_ctx, collection_input_ivalues_map);
148
+ // Generate a dictionary mapping each input torch::jit::Value to its min, opt, max tensors and store it in ctx
149
+ // TODO: Combine this within partition call
150
+ partitioning::populateInputIValues (&partitioning_ctx);
151
+
152
+ partitioning::partition (&partitioning_ctx);
144
153
145
154
for (auto & partitioned_block : partitioning_ctx.partitioned_blocks ) {
146
155
partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second ;
@@ -151,23 +160,24 @@ partitioning::GraphAndMapping BuildHybridGraph(
151
160
trt_engine_id << reinterpret_cast <const int *>(&seg_block);
152
161
153
162
if (seg_block.target () == partitioning::SegmentedBlock::kTensorRT ) {
154
- auto shapes = seg_block.in_shapes ();
155
- auto types = seg_block.in_types ();
156
- std::vector<ir::Input> inputs;
157
- for (size_t i = 0 ; i < shapes.size (); i++) {
158
- auto in = ir::Input (shapes[i]);
159
- in.dtype = util::ScalarTypeToTRTDataType (types[i]);
160
- inputs.push_back (in);
161
- }
163
+ auto inputs = seg_block.construct_inputs_spec ();
162
164
// update the input ranges for each segment
163
165
convert_info.inputs = ir::associate_specs_with_inputs (seg_block.g (), inputs, static_params);
164
166
165
167
// TODO mapping Inputs Ivalue to flatten one here
166
168
auto engine = conversion::ConvertBlockToEngine (seg_block.block (), convert_info, static_params);
167
169
auto temp_g = std::make_shared<torch::jit::Graph>();
168
170
auto device_spec = convert_info.engine_settings .device ;
169
- auto cuda_device = runtime::CudaDevice (device_spec.gpu_id , device_spec.device_type );
170
- AddEngineToGraph (new_mod, temp_g, engine, cuda_device, trt_engine_id.str (), true );
171
+ auto cuda_device = runtime::RTDevice (device_spec.gpu_id , device_spec.device_type );
172
+ AddEngineToGraph (
173
+ new_mod,
174
+ temp_g,
175
+ engine,
176
+ cuda_device,
177
+ std::vector<std::string>(),
178
+ std::vector<std::string>(),
179
+ trt_engine_id.str (),
180
+ true );
171
181
172
182
seg_block.update_graph (temp_g);
173
183
}
@@ -283,7 +293,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
283
293
torch::jit::Module new_mod (mod._ivalue ()->name () + " _trt" );
284
294
285
295
auto device_spec = cfg.convert_info .engine_settings .device ;
286
- auto cuda_device = runtime::CudaDevice (device_spec.gpu_id , device_spec.device_type );
296
+ auto cuda_device = runtime::RTDevice (device_spec.gpu_id , device_spec.device_type );
287
297
288
298
for (const torch::jit::Method& method : mod.get_methods ()) {
289
299
if (method.name ().compare (" forward" ) == 0 ) {
@@ -331,7 +341,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
331
341
" Not all operations in graph are supported by the compiler" );
332
342
// TODO find the right
333
343
auto engine = conversion::ConvertBlockToEngine (g->block (), cfg.convert_info , static_params);
334
- AddEngineToGraph (new_mod, new_g, engine, cuda_device);
344
+ AddEngineToGraph (new_mod, new_g, engine, cuda_device, std::vector<std::string>(), std::vector<std::string>() );
335
345
}
336
346
auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (method.name (), new_g);
337
347
auto schema = util::GenerateGraphSchema (new_method->name (), new_g);
@@ -342,12 +352,16 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
342
352
return new_mod;
343
353
}
344
354
345
- torch::jit::script::Module EmbedEngineInNewModule (const std::string& engine, runtime::CudaDevice cuda_device) {
355
+ torch::jit::script::Module EmbedEngineInNewModule (
356
+ const std::string& engine,
357
+ runtime::RTDevice cuda_device,
358
+ const std::vector<std::string>& input_binding_names,
359
+ const std::vector<std::string>& output_binding_names) {
346
360
std::ostringstream engine_id;
347
361
engine_id << reinterpret_cast <const int *>(&engine);
348
362
torch::jit::script::Module new_mod (" tensorrt_engine_mod_" + engine_id.str ());
349
363
auto new_g = std::make_shared<torch::jit::Graph>();
350
- AddEngineToGraph (new_mod, new_g, engine, cuda_device);
364
+ AddEngineToGraph (new_mod, new_g, engine, cuda_device, input_binding_names, output_binding_names );
351
365
auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (" forward" , new_g);
352
366
auto schema = util::GenerateGraphSchema (new_method->name (), new_g);
353
367
new_mod.type ()->addMethod (new_method);
0 commit comments