@@ -128,22 +128,6 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
   return conversion::VerifyConverterSupportForBlock(g->block());
 }
 
-std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
-  // Go through Lowering to simplify graph and extract weight parameters
-  auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
-
-  auto convert_cfg = std::move(cfg.convert_info);
-  auto g = graph_and_parameters.first;
-
-  auto params = graph_and_parameters.second;
-  auto named_params = conversion::get_named_params(g->inputs(), params);
-
-  LOG_INFO(*g << "(CompileGraph)\n");
-
-  auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
-  return std::move(engine);
-}
-
 void AddSegmentedBlockToGraph(
     std::shared_ptr<torch::jit::Graph>& g,
     partitioning::SegmentedBlock& seg,
@@ -237,15 +221,15 @@ void AddIfBlockToGraph(
 GraphAndMapping ConstructFallbackGraph(
     torch::jit::script::Module& new_mod,
     torch::jit::Block* block,
-    std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
+    std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
     CompileSpec cfg,
-    conversion::GraphParams named_params) {
+    ir::StaticParams static_params) {
   auto convert_cfg = cfg.convert_info;
   auto partition_info = cfg.partition_info;
 
   auto new_g = std::make_shared<torch::jit::Graph>();
 
-  auto segmented_blocks = partitioning::Partition(block, input_ivalues_map, partition_info);
+  auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info);
 
   // the mapping from lowering graph => fallback global graph
   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -259,13 +243,18 @@ GraphAndMapping ConstructFallbackGraph(
     trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
     if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+      auto shapes = seg_block.in_shapes();
+      auto types = seg_block.in_types();
       std::vector<ir::Input> inputs;
-      for (auto& shape : seg_block.in_shape()) {
-        inputs.push_back(ir::Input(shape));
+      for (size_t i = 0; i < shapes.size(); i++) {
+        auto in = ir::Input(shapes[i]);
+        in.dtype = util::ScalarTypeToTRTDataType(types[i]);
+        inputs.push_back(in);
       }
       // update the input ranges for each segment
-      convert_cfg.inputs = inputs;
-      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
+      convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
+
+      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
       auto device_spec = convert_cfg.engine_settings.device;
       auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
@@ -281,7 +270,7 @@ GraphAndMapping ConstructFallbackGraph(
         std::vector<GraphAndMapping> graph_and_mappings;
         for (auto cur_block : if_node->blocks()) {
           graph_and_mappings.push_back(
-              ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, named_params));
+              ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params));
         }
         AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
 
@@ -299,88 +288,157 @@ GraphAndMapping ConstructFallbackGraph(
   return {new_g, old_to_new_g};
 }
 
-torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
-  // TODO: Should be doing a functional transform but need PR #31978
-  // [jit] More robust mangling
-  // torch::jit::script::Module new_mod = mod.clone();
-  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
-  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
-  for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Compile only forward methods. forward method contains the entire graph.
-    if (method.name().compare("forward") == 0) {
-      auto new_g = std::make_shared<torch::jit::Graph>();
-      auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);
+void MapInputsAndDetermineDTypes(
+    CompileSpec& cfg,
+    std::shared_ptr<torch::jit::Graph>& g,
+    ir::StaticParams& static_params,
+    ir::TypeMap& first_use_type_map) {
+  // Associate input specs with inputs
+  cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
+
+  for (auto& in : g->inputs()) {
+    auto est_type_opt = first_use_type_map.find(in)->second;
+    ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
+    if (est_type_opt && !spec.dtype_is_user_defined) {
+      // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
+      // type
+      LOG_INFO(
+          "Since input type is not explicitly defined, inferring using first tensor calculation\nFound input "
+          << in->debugName() << " has type " << est_type_opt.value()
+          << ". If this is incorrect explicitly set dtype for input and file a bug");
+      spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
+    } else if (!est_type_opt && !spec.dtype_is_user_defined) {
+      // If we cannot calculate the type and the user did not define the type, then default to FP32
+      LOG_WARNING(
+          "Cannot infer input type from calculations in graph for input "
+          << in->debugName() << ". Assuming it is Float32. If not, specify input type explicitly");
+      spec.dtype = nvinfer1::DataType::kFLOAT;
+    } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
+      if (!est_type_opt) {
+        LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
+      } else {
+        if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
+          std::stringstream ss;
+          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+          ss << cfg.convert_info.inputs.find(in)->second.dtype;
+          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+          ss << est_type_opt.value() << std::endl;
+          ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
+          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+          ss << "compatibility with PyTorch's data type convention is required.\n";
+          ss << "If you do indeed see errors at runtime either:\n";
+          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+          ss << "- Disable partial compilation by setting require_full_compilation to True";
+          auto warn_str = ss.str();
+          LOG_WARNING(warn_str);
+          // Overwrite type map with user settings
+          first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+        }
+      }
+    } else {
+      // The user defined the type so no changes are necessary
+    }
+  }
+}
 
-      auto g = graph_and_parameters.first;
-      auto params = graph_and_parameters.second;
-      auto named_params = conversion::get_named_params(g->inputs(), params);
-      LOG_INFO("(LoweredGraph)\n" << *g);
+uint64_t GetRecommendedWorkspaceSize(const runtime::CudaDevice& device) {
+  if (device.major < 6) {
+    return 256 * (1 << 20);
+  } else {
+    return 1 << 30;
+  }
+}
 
-      std::unordered_map<torch::jit::Value*, ir::Input> inputs;
-      for (size_t i = 0; i < g->inputs().size(); ++i) {
-        inputs.insert({g->inputs()[i], cfg.convert_info.inputs[i]});
-      }
-      auto input_ivalues_map = partitioning::generateRandomInputs(inputs);
-      auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
-      new_g = graph_and_mapping.first;
-      LOG_INFO("(FallbackGraph)\n" << *new_g);
+std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
+  // Go through Lowering to simplify graph and extract weight parameters
+  auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
 
-      // if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
-      // module
-      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-        LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
-        return mod;
-      }
+  auto g = graph_and_parameters.first;
+  TRTORCH_CHECK(
+      conversion::VerifyConverterSupportForBlock(g->block()),
+      "Not all operations in graph are supported by the compiler");
+  auto params = graph_and_parameters.second;
+  auto static_params = ir::get_static_params(g->inputs(), params);
+  // Infer the type of an input from the weights of the calculation
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
 
-      auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
-      auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
-      new_mod.type()->addMethod(new_method);
-      new_method->setSchema(schema);
-    }
+  // GPU default WS size : 1 GB
+  // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
+  auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
+  auto device_spec = cfg.convert_info.engine_settings.device;
+  auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+  if (workspace_size == 0) {
+    cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
   }
 
-  return new_mod;
+  MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+
+  auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
+
+  return std::move(engine);
 }
 
-torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
-  // TODO: not sure how to deal with duplicated code here, so just cut out a branch temporally
-  if (cfg.partition_info.enabled) {
-    return CompileGraphWithFallback(mod, cfg);
-  }
-  auto device_spec = cfg.convert_info.engine_settings.device;
+torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
+  torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
 
   // GPU default WS size : 1 GB
   // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
   auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
-  cudaDeviceProp device_prop;
-  cudaGetDeviceProperties(&device_prop, device_spec.gpu_id);
+  auto device_spec = cfg.convert_info.engine_settings.device;
+  auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
   if (workspace_size == 0) {
-    if (device_prop.major < 6) {
-      cfg.convert_info.engine_settings.workspace_size = 256 * (1 << 20);
-    } else {
-      cfg.convert_info.engine_settings.workspace_size = 1 << 30;
-    }
+    cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
   }
 
-  // TODO: Should be doing a functional transform but need PR #31978
-  // [jit] More robust mangling
-  // torch::jit::script::Module new_mod = mod.clone();
-  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
-  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
-  for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Compile only forward methods. forward method contains the entire graph.
+  for (const torch::jit::Method& method : mod.get_methods()) {
     if (method.name().compare("forward") == 0) {
-      auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
       auto new_g = std::make_shared<torch::jit::Graph>();
-      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-      AddEngineToGraph(new_mod, new_g, engine, cuda_device);
+
+      auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);
+
+      auto g = graph_and_parameters.first;
+      auto params = graph_and_parameters.second;
+      auto static_params = ir::get_static_params(g->inputs(), params);
+      // Infer the type of an input from the weights of the calculation
+      auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+
+      MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+
+      if (cfg.partition_info.enabled &&
+          (cfg.lower_info.forced_fallback_modules.size() == 0 &&
+           cfg.partition_info.forced_fallback_operators.size() == 0 &&
+           conversion::VerifyConverterSupportForBlock(g->block(), true))) {
+        LOG_INFO("Skipping partitioning since model is fully supported");
+      }
+
+      if (cfg.partition_info.enabled &&
+          !(cfg.lower_info.forced_fallback_modules.size() == 0 &&
+            cfg.partition_info.forced_fallback_operators.size() == 0 &&
+            conversion::VerifyConverterSupportForBlock(g->block(), false))) {
+        auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
+        auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
+        new_g = graph_and_mapping.first;
+        LOG_INFO("Segmented Graph: " << *new_g);
+
+        // if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
+        // module
+        if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+          LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
+          return mod;
+        }
+      } else {
+        TRTORCH_CHECK(
+            conversion::VerifyConverterSupportForBlock(g->block()),
+            "Not all operations in graph are supported by the compiler");
+        auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
+        AddEngineToGraph(new_mod, new_g, engine, cuda_device);
+      }
       auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
       auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
       new_mod.type()->addMethod(new_method);
       new_method->setSchema(schema);
     }
   }
-
   return new_mod;
 }
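
Below is a minimal caller-side sketch of how the refactored entry point might be driven. It is not part of this diff: the header path, the trtorch::core namespace alias, and the CompileSpec/Input constructor forms are assumptions, while the cfg.partition_info.enabled field and the CompileGraph signature follow the code above.

#include "core/compiler.h"  // assumed header exposing CompileSpec and CompileGraph

namespace core = trtorch::core;  // assumption: the functions above live in trtorch::core

torch::jit::Module CompileForTRT(const torch::jit::Module& mod) {
  // Shape-only input spec; dtype is left unset so MapInputsAndDetermineDTypes
  // either infers it from the graph or defaults to FP32, as shown above.
  core::CompileSpec cfg({core::ir::Input({1, 3, 224, 224})});  // constructor form is an assumption
  cfg.partition_info.enabled = true;  // allow Torch fallback for unsupported ops
  return core::CompileGraph(mod, cfg);
}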