
Commit 04d7fa2

Merge pull request #667 from NVIDIA/kwargs_py_api
kwargs API and fix for partitioning when weights are provided in FP16
2 parents f49604d + 4a12861 commit 04d7fa2

89 files changed: +1897 additions, -928 deletions


core/compiler.cpp

Lines changed: 141 additions & 83 deletions
@@ -128,22 +128,6 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
   return conversion::VerifyConverterSupportForBlock(g->block());
 }

-std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
-  // Go through Lowering to simplify graph and extract weight parameters
-  auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
-
-  auto convert_cfg = std::move(cfg.convert_info);
-  auto g = graph_and_parameters.first;
-
-  auto params = graph_and_parameters.second;
-  auto named_params = conversion::get_named_params(g->inputs(), params);
-
-  LOG_INFO(*g << "(CompileGraph)\n");
-
-  auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
-  return std::move(engine);
-}
-
 void AddSegmentedBlockToGraph(
     std::shared_ptr<torch::jit::Graph>& g,
     partitioning::SegmentedBlock& seg,
@@ -237,15 +221,15 @@ void AddIfBlockToGraph(
 GraphAndMapping ConstructFallbackGraph(
     torch::jit::script::Module& new_mod,
     torch::jit::Block* block,
-    std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
+    std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
     CompileSpec cfg,
-    conversion::GraphParams named_params) {
+    ir::StaticParams static_params) {
   auto convert_cfg = cfg.convert_info;
   auto partition_info = cfg.partition_info;

   auto new_g = std::make_shared<torch::jit::Graph>();

-  auto segmented_blocks = partitioning::Partition(block, input_ivalues_map, partition_info);
+  auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info);

   // the mapping from lowering graph => fallback global graph
   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -259,13 +243,18 @@ GraphAndMapping ConstructFallbackGraph(
     trt_engine_id << reinterpret_cast<const int*>(&seg_block);

     if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+      auto shapes = seg_block.in_shapes();
+      auto types = seg_block.in_types();
       std::vector<ir::Input> inputs;
-      for (auto& shape : seg_block.in_shape()) {
-        inputs.push_back(ir::Input(shape));
+      for (size_t i = 0; i < shapes.size(); i++) {
+        auto in = ir::Input(shapes[i]);
+        in.dtype = util::ScalarTypeToTRTDataType(types[i]);
+        inputs.push_back(in);
       }
       // update the input ranges for each segments
-      convert_cfg.inputs = inputs;
-      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
+      convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
+
+      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
       auto device_spec = convert_cfg.engine_settings.device;
       auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
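
The hunk above is the heart of the FP16 partitioning fix: inputs handed to each TensorRT segment now carry the dtype recorded during partitioning instead of silently defaulting to FP32. A minimal standalone sketch of that pattern; DType, Input, and BuildTypedInputs are illustrative stand-ins, not the project's real ir types:

// Sketch only: DType and Input are simplified stand-ins for TRTorch's ir types.
#include <cstdint>
#include <cstdio>
#include <vector>

enum class DType { kFloat, kHalf };

struct Input {
  std::vector<int64_t> shape;
  DType dtype = DType::kFloat; // old behavior: every segment input became FP32
};

// Zip each segment's recorded input shapes with its recorded first-use types,
// mirroring the new loop in ConstructFallbackGraph.
std::vector<Input> BuildTypedInputs(
    const std::vector<std::vector<int64_t>>& shapes,
    const std::vector<DType>& types) {
  std::vector<Input> inputs;
  for (size_t i = 0; i < shapes.size(); i++) {
    Input in{shapes[i]};
    in.dtype = types[i]; // FP16 weights now yield FP16 segment inputs
    inputs.push_back(in);
  }
  return inputs;
}

int main() {
  auto inputs = BuildTypedInputs({{1, 3, 224, 224}}, {DType::kHalf});
  std::printf("first input is half: %s\n", inputs[0].dtype == DType::kHalf ? "yes" : "no");
}
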
@@ -281,7 +270,7 @@ GraphAndMapping ConstructFallbackGraph(
       std::vector<GraphAndMapping> graph_and_mappings;
       for (auto cur_block : if_node->blocks()) {
         graph_and_mappings.push_back(
-            ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, named_params));
+            ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params));
       }
       AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);

@@ -299,88 +288,157 @@ GraphAndMapping ConstructFallbackGraph(
   return {new_g, old_to_new_g};
 }

-torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
-  // TODO: Should be doing a functional transform but need PR #31978
-  // [jit] More robust mangling
-  // torch::jit::script::Module new_mod = mod.clone();
-  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
-  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
-  for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Compile only forward methods. forward method contains the entire graph.
-    if (method.name().compare("forward") == 0) {
-      auto new_g = std::make_shared<torch::jit::Graph>();
-      auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);
+void MapInputsAndDetermineDTypes(
+    CompileSpec& cfg,
+    std::shared_ptr<torch::jit::Graph>& g,
+    ir::StaticParams& static_params,
+    ir::TypeMap& first_use_type_map) {
+  // Associate input specs with inputs
+  cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
+
+  for (auto& in : g->inputs()) {
+    auto est_type_opt = first_use_type_map.find(in)->second;
+    ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
+    if (est_type_opt && !spec.dtype_is_user_defined) {
+      // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
+      // type
+      LOG_INFO(
+          "Since input type is not explicitly defined, infering using first tensor calculation\n Found input "
+          << in->debugName() << " has type " << est_type_opt.value()
+          << ". If this is incorrect explicitly set dtype for input and file a bug");
+      spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
+    } else if (!est_type_opt && !spec.dtype_is_user_defined) {
+      // If we cannot calculate the type and the user did not define the type, then default to FP32
+      LOG_WARNING(
+          "Cannot infer input type from calcuations in graph for input "
+          << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
+      spec.dtype = nvinfer1::DataType::kFLOAT;
+    } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
+      if (!est_type_opt) {
+        LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
+      } else {
+        if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
+          std::stringstream ss;
+          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+          ss << cfg.convert_info.inputs.find(in)->second.dtype;
+          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+          ss << est_type_opt.value() << std::endl;
+          ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
+          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+          ss << "compatibility with PyTorch's data type convention is required.\n";
+          ss << "If you do indeed see errors at runtime either:\n";
+          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+          ss << "- Disable partial compilation by setting require_full_compilation to True";
+          auto warn_str = ss.str();
+          LOG_WARNING(warn_str);
+          // Overwrite type map with user settings
+          first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+        }
+      }
+    } else {
+      // The user defined the type so no changes are necessary
+    }
+  }
+}

-      auto g = graph_and_parameters.first;
-      auto params = graph_and_parameters.second;
-      auto named_params = conversion::get_named_params(g->inputs(), params);
-      LOG_INFO("(LoweredGraph)\n" << *g);
+uint64_t GetRecommendedWorkspaceSize(const runtime::CudaDevice& device) {
+  if (device.major < 6) {
+    return 256 * (1 << 20);
+  } else {
+    return 1 << 30;
+  }
+}

-      std::unordered_map<torch::jit::Value*, ir::Input> inputs;
-      for (size_t i = 0; i < g->inputs().size(); ++i) {
-        inputs.insert({g->inputs()[i], cfg.convert_info.inputs[i]});
-      }
-      auto input_ivalues_map = partitioning::generateRandomInputs(inputs);
-      auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
-      new_g = graph_and_mapping.first;
-      LOG_INFO("(FallbackGraph)\n" << *new_g);
+std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
+  // Go through Lowering to simplify graph and extract weight parameters
+  auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);

-      // if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
-      // module
-      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-        LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
-        return mod;
-      }
+  auto g = graph_and_parameters.first;
+  TRTORCH_CHECK(
+      conversion::VerifyConverterSupportForBlock(g->block()),
+      "Not all operations in graph are supported by the compiler");
+  auto params = graph_and_parameters.second;
+  auto static_params = ir::get_static_params(g->inputs(), params);
+  // Infer the type of an input from the weights of the calculation
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());

-      auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
-      auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
-      new_mod.type()->addMethod(new_method);
-      new_method->setSchema(schema);
-    }
+  // GPU default WS size : 1 GB
+  // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
+  auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
+  auto device_spec = cfg.convert_info.engine_settings.device;
+  auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+  if (workspace_size == 0) {
+    cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
   }

-  return new_mod;
+  MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+
+  auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
+
+  return std::move(engine);
 }

-torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
-  // TODO: not sure how to deal with duplicated code here, so just cut out a branch temporally
-  if (cfg.partition_info.enabled) {
-    return CompileGraphWithFallback(mod, cfg);
-  }
-  auto device_spec = cfg.convert_info.engine_settings.device;
+torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
+  torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");

   // GPU default WS size : 1 GB
   // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
   auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
-  cudaDeviceProp device_prop;
-  cudaGetDeviceProperties(&device_prop, device_spec.gpu_id);
+  auto device_spec = cfg.convert_info.engine_settings.device;
+  auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
   if (workspace_size == 0) {
-    if (device_prop.major < 6) {
-      cfg.convert_info.engine_settings.workspace_size = 256 * (1 << 20);
-    } else {
-      cfg.convert_info.engine_settings.workspace_size = 1 << 30;
-    }
+    cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
   }

-  // TODO: Should be doing a functional transform but need PR #31978
-  // [jit] More robust mangling
-  // torch::jit::script::Module new_mod = mod.clone();
-  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
-  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
-  for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Compile only forward methods. forward method contains the entire graph.
+  for (const torch::jit::Method& method : mod.get_methods()) {
     if (method.name().compare("forward") == 0) {
-      auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
       auto new_g = std::make_shared<torch::jit::Graph>();
-      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-      AddEngineToGraph(new_mod, new_g, engine, cuda_device);
+
+      auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);
+
+      auto g = graph_and_parameters.first;
+      auto params = graph_and_parameters.second;
+      auto static_params = ir::get_static_params(g->inputs(), params);
+      // Infer the type of an input from the weights of the calculation
+      auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+
+      MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+
+      if (cfg.partition_info.enabled &&
+          (cfg.lower_info.forced_fallback_modules.size() == 0 &&
+           cfg.partition_info.forced_fallback_operators.size() == 0 &&
+           conversion::VerifyConverterSupportForBlock(g->block(), true))) {
+        LOG_INFO("Skipping partitioning since model is fully supported");
+      }
+
+      if (cfg.partition_info.enabled &&
+          !(cfg.lower_info.forced_fallback_modules.size() == 0 &&
+            cfg.partition_info.forced_fallback_operators.size() == 0 &&
+            conversion::VerifyConverterSupportForBlock(g->block(), false))) {
+        auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
+        auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
+        new_g = graph_and_mapping.first;
+        LOG_INFO("Segmented Graph: " << *new_g);
+
+        // if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
+        // module
+        if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+          LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
+          return mod;
+        }
+      } else {
+        TRTORCH_CHECK(
+            conversion::VerifyConverterSupportForBlock(g->block()),
+            "Not all operations in graph are supported by the compiler");
+        auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
+        AddEngineToGraph(new_mod, new_g, engine, cuda_device);
+      }
       auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
       auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
       new_mod.type()->addMethod(new_method);
       new_method->setSchema(schema);
     }
   }
-
   return new_mod;
 }

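GetRecommendedWorkspaceSize simply centralizes the workspace heuristic that was previously duplicated inline: 256 MiB for compute capability 5.x boards (Jetson Nano/TX1 class), 1 GiB otherwise. A self-contained sketch, assuming only the compute-capability major version matters:

// Sketch only: the real function takes a runtime::CudaDevice; here the compute
// capability major version is passed directly.
#include <cstdint>
#include <cstdio>

uint64_t GetRecommendedWorkspaceSize(int cc_major) {
  if (cc_major < 6) {
    return 256ull * (1 << 20); // 256 MiB for Maxwell-class boards (Nano/TX1)
  }
  return 1ull << 30; // 1 GiB elsewhere
}

int main() {
  std::printf("cc 5.x -> %llu bytes\n",
              static_cast<unsigned long long>(GetRecommendedWorkspaceSize(5)));
}
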
core/compiler.h

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,8 @@ namespace trtorch {
 namespace core {

 struct CompileSpec {
-  CompileSpec(std::vector<ir::Input> inputs) : convert_info(std::move(inputs)) {}
+  CompileSpec(std::vector<ir::Input> inputs) : inputs(inputs) {}
+  std::vector<ir::Input> inputs;
   conversion::ConversionInfo convert_info;
   lowering::LowerInfo lower_info;
   partitioning::PartitionInfo partition_info;
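
This header change keeps the user's input specs on CompileSpec itself rather than moving them into ConversionInfo at construction time, so compiler.cpp can later bind them to graph inputs by value via ir::associate_specs_with_inputs. A toy sketch of that ownership change, with types reduced to stand-ins:

// Sketch only: Input and ConversionInfo are simplified stand-ins for the real types.
#include <cstdint>
#include <vector>

struct Input {
  std::vector<int64_t> shape;
};

struct ConversionInfo {
  std::vector<Input> inputs; // now filled later, once specs are matched to graph values
};

struct CompileSpec {
  CompileSpec(std::vector<Input> inputs) : inputs(inputs) {}
  std::vector<Input> inputs; // kept on the top-level spec, as in the diff
  ConversionInfo convert_info;
};

int main() {
  CompileSpec cfg({Input{{1, 3, 224, 224}}});
  // Before association runs, the conversion settings hold no inputs yet.
  return cfg.convert_info.inputs.empty() ? 0 : 1;
}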

core/conversion/BUILD

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@ config_setting(
 cc_library(
     name = "conversion",
     srcs = [
-        "InterfaceTypes.cpp",
         "conversion.cpp",
         "conversion_ignorelist.cpp",
     ],
