
Commit bfa4c9a

skip dummy inference and run_shape_analysis (#3212)
1 parent e2a27a0 commit bfa4c9a

File tree

15 files changed: +227 -167 lines changed

py/torch_tensorrt/_compile.py

Lines changed: 18 additions & 5 deletions
```diff
@@ -502,19 +502,24 @@ def save(
                 "Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
             )
         else:
+            if arg_inputs is not None:
+                logger.warning(
+                    "Provided model is a torch.jit.ScriptModule, inputs or arg_inputs is not necessary during save."
+                )
             torch.jit.save(module, file_path)
     elif module_type == _ModuleType.ep:
         if output_format == "torchscript":
             raise ValueError(
                 "Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
             )
         else:
+            if arg_inputs is not None:
+                logger.warning(
+                    "Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
+                )
             torch.export.save(module, file_path)
     elif module_type == _ModuleType.fx:
-        if arg_inputs is None:
-            raise ValueError(
-                "Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
-            )
+
         # The module type is torch.fx.GraphModule
         if output_format == "torchscript":
             module_ts = torch.jit.trace(
@@ -525,11 +530,19 @@ def save(
             if not retrace:
                 from torch_tensorrt.dynamo._exporter import export
 
-                exp_program = export(module, arg_inputs, kwarg_inputs)
+                if arg_inputs is not None:
+                    logger.warning(
+                        "Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
+                    )
+                exp_program = export(module)
                 torch.export.save(exp_program, file_path)
             else:
                 from torch._higher_order_ops.torchbind import enable_torchbind_tracing
 
+                if arg_inputs is None:
+                    raise ValueError(
+                        "Provided model is a torch.fx.GraphModule and retrace is True, however the inputs or arg_inputs are empty. Please provide valid torch.tensors as inputs or arg_inputs to trace and save the model"
+                    )
                 with enable_torchbind_tracing():
                     exp_program = torch.export.export(
                         module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
```
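With this change, `torch_tensorrt.save` only needs example inputs when `retrace=True`; in the default path the exporter reuses the graph metadata recorded at compile time. A minimal usage sketch (the model and tensor names are illustrative, and a CUDA build of Torch-TensorRT is assumed):

```python
import torch
import torch_tensorrt

class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1

model = TinyModel().eval().cuda()
inputs = (torch.randn(2, 4, device="cuda"),)

# Compile with the dynamo frontend; inputs are required here as usual.
trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)

# retrace=False (default): arg_inputs may now be omitted; passing them
# only triggers the new warning.
torch_tensorrt.save(trt_gm, "trt_model.ep")

# retrace=True: torch.export re-traces the module, so arg_inputs remain
# mandatory and omitting them raises the new ValueError.
torch_tensorrt.save(trt_gm, "trt_model.ep", retrace=True, arg_inputs=inputs)
```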

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 27 additions & 1 deletion
```diff
@@ -36,6 +36,7 @@
 )
 from torch_tensorrt.dynamo.utils import (
     get_flat_args_with_check,
+    get_output_metadata,
     parse_graph_io,
     prepare_inputs,
     set_log_level,
@@ -302,7 +303,6 @@ def compile(
 
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
-
     exported_program = pre_export_lowering(exported_program, settings)
     exported_program = exported_program.run_decompositions(
         get_decompositions(enable_experimental_decompositions)
@@ -433,6 +433,12 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     if not settings.use_fast_partitioner:
         dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))
 
+    submodule_node_dict = {}
+    for node in partitioned_module.graph.nodes:
+        if "_run_on_acc" not in node.name:
+            continue
+        submodule_node_dict[node.name] = node
+
     # Store TRT replicas of Torch subgraphs
     trt_modules = {}
     # Iterate over all components that can be accelerated
@@ -452,6 +458,26 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             )
             continue
 
+        if name not in submodule_node_dict:
+            raise ValueError(
+                f"node_name: {name} does not exist in the submodule node dictionary"
+            )
+
+        # set the submodule metadata back to the parent trt_module_node
+        metadata_list = get_output_metadata(submodule)
+        assert len(metadata_list) > 0
+        metadata_keys = ["val", "tensor_meta"]
+        for key in metadata_keys:
+            if key not in submodule_node_dict[name].meta:
+                meta_val_list = [
+                    metadata[key] for metadata in metadata_list if key in metadata
+                ]
+                submodule_node_dict[name].meta[key] = meta_val_list
+                logger.debug(
+                    f"Updated metadata for node: {name} with its corresponding submodule outputs"
+                )
+                break
+
         subgraph_data = PerSubgraphData()
         subgraph_data.subgraph_name = name
         subgraph_data.subgraph_op_count = len(
```
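The copy-back above relies only on standard `torch.fx` metadata: each `_run_on_acc_*` call_module node in the partitioned parent graph receives the `val`/`tensor_meta` entries gathered from its submodule's output node. A self-contained sketch of the same pattern, assuming the submodule returns a tuple of tensors (the helper name is illustrative, not the actual `get_output_metadata`):

```python
import torch
from torch.fx import GraphModule

def copy_submodule_output_metadata(parent: GraphModule, name: str) -> None:
    """Copy 'val'/'tensor_meta' from a submodule's output args onto the
    parent call_module node, mirroring the compile_module change above."""
    submodule = getattr(parent, name)
    parent_node = next(n for n in parent.graph.nodes if n.name == name)

    # The output node's first arg is the tuple of nodes being returned;
    # their .meta dicts describe each output tensor.
    output_node = next(n for n in submodule.graph.nodes if n.op == "output")
    metadata_list = [arg.meta for arg in output_node.args[0]]

    for key in ("val", "tensor_meta"):
        if key not in parent_node.meta:
            parent_node.meta[key] = [m[key] for m in metadata_list if key in m]
            break
```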

py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 11 additions & 30 deletions
```diff
@@ -1,6 +1,6 @@
 import copy
 import operator
-from typing import Any, Dict, Optional, Sequence, Tuple, cast
+from typing import Any, Dict, Sequence, Tuple, cast
 
 import torch
 from torch._guards import detect_fake_mode
@@ -16,31 +16,24 @@
     OutputSpec,
     TensorArgument,
 )
-from torch_tensorrt.dynamo import partitioning
 
 
 def export(
     gm: torch.fx.GraphModule,
-    inputs: Sequence[torch.Tensor],
-    kwarg_inputs: Optional[dict[str, Any]] = None,
 ) -> ExportedProgram:
     """Export the result of TensorRT compilation into the desired output format.
 
     Arguments:
         gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
         inputs (torch.Tensor): Torch input tensors
     """
-    if kwarg_inputs is None:
-        kwarg_inputs = {}
-    patched_module = transform(gm, inputs, kwarg_inputs)
+    patched_module = transform(gm)
     exp_program = create_trt_exp_program(patched_module)
     return exp_program
 
 
 def transform(
     gm: torch.fx.GraphModule,
-    inputs: Sequence[torch.Tensor],
-    kwarg_inputs: Optional[dict[str, Any]] = None,
 ) -> torch.fx.GraphModule:
     """
     Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
@@ -55,14 +48,10 @@ def transform(
     """
     # Make a copy the graph since this function transforms the input graph and changes it's attributes.
    # This transformed graph is meant to be consumed by `create_trt_exp_program`
-    if kwarg_inputs is None:
-        kwarg_inputs = {}
     gm = copy.deepcopy(gm)
-    # Run shape analysis
-    _, outputs_map = partitioning.run_shape_analysis(gm, inputs, kwarg_inputs)
 
     # Inline TensorRT submodules
-    inline_trt_modules(gm, outputs_map)
+    inline_trt_modules(gm)
 
     # Inline pytorch submodules
     inline_torch_modules(gm)
@@ -361,9 +350,7 @@ def create_trt_exp_program(
     return trt_exp_program
 
 
-def inline_trt_modules(
-    gm: torch.fx.GraphModule, outputs_map: Dict[Any, Sequence[Any]]
-) -> torch.fx.GraphModule:
+def inline_trt_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """
     Replace TRT submodules with trt engine nodes.
     """
@@ -379,7 +366,11 @@ def inline_trt_modules(
         trt_module_node = trt_module_node[0]
         assert trt_module_node.args
 
-        num_outputs = len(outputs_map[trt_module_node.name])
+        if "val" not in trt_module_node.meta:
+            raise ValueError(
+                f"trt_module_node: {trt_module_node.name} does not have the metadata which should be set during dynamo compile_module step."
+            )
+        num_outputs = len(trt_module_node.meta["val"])
         # Insert a call_function node to perform inference on TRT engine
         with gm.graph.inserting_before(trt_module_node):
             engine_name = f"{name}_engine"
@@ -390,19 +381,9 @@ def inline_trt_modules(
             torch.ops.tensorrt.execute_engine.default,
             (trt_module_node.args, engine_node),
         )
-        trt_node.meta["val"] = []
+        # set trt_node.meta with trt_module_node.meta
         assert num_outputs > 0
-        # Generate meta data for TRT node (a FakeTensor with corresponding output shape)
-        for idx in range(num_outputs):
-            trt_node.meta["val"].append(
-                cast(
-                    FakeTensor,
-                    torch.empty_strided(
-                        tuple(outputs_map[trt_module_node.name][idx]),
-                        tuple([1] * len(outputs_map[trt_module_node.name][idx])),
-                    ),
-                )
-            )
+        trt_node.meta["val"] = trt_module_node.meta["val"]
 
         # meta["val"] should be a lighter version of a tensor. For eg: it should be a FakeTensor (with output shape and dtype properties)
         # Lighter version of a custom_obj is not defined clearly. meta["val"] does not have any type expectations but
```
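`inline_trt_modules` can now size its engine nodes from metadata alone because every `_run_on_acc_*` node carries the FakeTensors recorded during `compile_module`. A quick inspection sketch, assuming `partitioned_gm` is the module produced by that step:

```python
# Hedged sketch: what the exporter now reads instead of calling
# run_shape_analysis. Each FakeTensor carries shape and dtype.
for node in partitioned_gm.graph.nodes:
    if "_run_on_acc" in node.name and "val" in node.meta:
        fake_outputs = node.meta["val"]  # one FakeTensor per engine output
        print(
            node.name,
            len(fake_outputs),
            [tuple(t.shape) for t in fake_outputs],
            [t.dtype for t in fake_outputs],
        )
```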

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -74,8 +74,6 @@ def construct_refit_mapping(
 
     output_dtypes = infer_module_output_dtypes(
         module,
-        inputs,
-        settings.device,
         truncate_double=settings.truncate_double,
     )
 
```

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 9 additions & 60 deletions
```diff
@@ -5,8 +5,6 @@
 
 import tensorrt as trt
 import torch
-from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
-from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
@@ -17,58 +15,22 @@
     TRTInterpreterResult,
 )
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs
+from torch_tensorrt.dynamo.utils import get_output_dtypes
 
 logger = logging.getLogger(__name__)
 
 
 def infer_module_output_dtypes(
     module: torch.fx.GraphModule,
-    inputs: Sequence[Input],
-    device: Device,
-    kwarg_inputs: Optional[dict[str, Any]] = None,
     truncate_double: bool = False,
 ) -> List[dtype]:
     """
-    This function performs model inference to determine the output dtypes
-    and truncates them accordingly. inputs can be either arg_inputs or flattened input list.
-    If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input.
+    This function get the output dtypes from node.meta['val'] which was set during dynamo compile_module step
+    and truncates them accordingly.
     """
-    # TODO: We can also determine output dtypes from the module.graph based on node metadata.
-    # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata,
-    # so we stick to the model inference approach currently.
-    with unset_fake_temporarily():
-        # Get the device on which the model exists
-        # For large models, this can be done on CPU to save GPU memory allocation for TRT.
-        device = get_model_device(module)
-        torch_inputs = get_torch_inputs(inputs, device)
-        if kwarg_inputs is None:
-            kwarg_inputs = {}
-        torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
-        module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
-        if not isinstance(module_outputs, (list, tuple)):
-            module_outputs = [module_outputs]
-
-        # Int64 outputs can sometimes be generated from within other operators
-        # such as aten.sum - such outputs can be truncated
-        output_dtypes = []
-        for output in module_outputs:
-            output_ = output
-            # We don't need to check if output is nested here because the input module will be flattened
-            if not isinstance(output, torch.Tensor):
-                if isinstance(output, str):
-                    raise ValueError(
-                        f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
-                    )
-                else:
-                    output_ = torch.tensor(output)
-
-            if truncate_double and output_.dtype == dtype.float64:
-                output_dtypes.append(dtype.float32)
-            else:
-                output_dtypes.append(dtype._from(output_.dtype))
-
-    return output_dtypes
+    outputs = [node for node in module.graph.nodes if node.op == "output"]
+    outputs = outputs[0].args
+    return get_output_dtypes(outputs, truncate_double)
 
 
 def interpret_module_to_result(
@@ -91,22 +53,9 @@ def interpret_module_to_result(
     Returns:
         TRTInterpreterResult
     """
-    if arg_inputs is not None:
-        output_dtypes = infer_module_output_dtypes(
-            module,
-            arg_inputs,
-            settings.device,
-            kwarg_inputs=kwarg_inputs,
-            truncate_double=settings.truncate_double,
-        )
-    else:
-        # args and kwargs are combined and flattened to one list
-        output_dtypes = infer_module_output_dtypes(
-            module,
-            inputs,
-            settings.device,
-            truncate_double=settings.truncate_double,
-        )
+    output_dtypes = infer_module_output_dtypes(
+        module, truncate_double=settings.truncate_double
+    )
 
     interpreter = TRTInterpreter(
         module,
```
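`get_output_dtypes` itself is not shown in this diff; below is a hedged re-implementation of the logic it presumably applies, mapping each output node's `meta['val']` FakeTensor to a Torch-TensorRT dtype and truncating doubles on request. It assumes the output args have already been flattened to `fx.Node`s:

```python
import torch
from torch_tensorrt._enums import dtype

def get_output_dtypes_sketch(outputs, truncate_double: bool = False):
    """Illustrative stand-in for torch_tensorrt.dynamo.utils.get_output_dtypes.
    Assumes each output arg is an fx.Node whose meta['val'] is a FakeTensor."""
    output_dtypes = []
    for node in outputs:
        fake = node.meta["val"]
        if truncate_double and fake.dtype == torch.float64:
            output_dtypes.append(dtype.float32)
        else:
            output_dtypes.append(dtype._from(fake.dtype))
    return output_dtypes
```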
py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py

Lines changed: 14 additions & 18 deletions
```diff
@@ -1,11 +1,11 @@
 import logging
-from typing import Callable, Tuple
 
 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
+from torch_tensorrt.dynamo.utils import get_metadata, set_metadata
 
 logger = logging.getLogger(__name__)
 
@@ -14,33 +14,29 @@ def lower_linear(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
     """Replace aten.linear with an equivalent implementation which can be easily converted to TRT"""
-    orig, replacement = linear_replacement()
-
-    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
-        gm = clean_up_graph_after_modifications(gm)
-        logger.debug(f"Graph after lowering linear:\n{gm.graph}")
-
-    return gm
-
-
-def linear_replacement() -> Tuple[
-    torch.fx.GraphModule,
-    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-]:
-    """Constructs the original and replacement functions for linear"""
+    orig_op = torch.ops.aten.addmm.default
+    replacement_op = torch.ops.aten.linear.default
 
     # Original graph
     def orig(
         input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
     ) -> torch.Tensor:
         W_T = torch.ops.aten.permute.default(weight, [1, 0])
-        out = torch.ops.aten.addmm.default(bias, input, W_T)
+        out = orig_op(bias, input, W_T)
         return out
 
     # Replacement graph
     def replacement(
         input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
     ) -> torch.Tensor:
-        return torch.ops.aten.linear.default(input, weight, bias)
+        return replacement_op(input, weight, bias)
+
+    metadata = get_metadata(gm, orig_op)
+    replaced_nodes = torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement)
+
+    if len(replaced_nodes) > 0:
+        gm = clean_up_graph_after_modifications(gm)
+        set_metadata(gm, replacement_op, metadata)
+        logger.debug(f"Graph after lowering linear:\n{gm.graph}")
 
-    return orig, replacement
+    return gm
```
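`torch.fx.subgraph_rewriter.replace_pattern` does not preserve node metadata, which is why the pass now snapshots it before rewriting and restores it afterwards. A minimal sketch of that save/restore pattern (the real `get_metadata`/`set_metadata` live in `torch_tensorrt.dynamo.utils`; the bodies below are illustrative and assume the rewritten nodes appear in the same graph order):

```python
from typing import Any, Dict, List

import torch
from torch.fx import GraphModule

def get_metadata_sketch(gm: GraphModule, target: Any) -> List[Dict[str, Any]]:
    """Snapshot the meta dicts of all nodes calling `target`, in graph order."""
    return [node.meta for node in gm.graph.nodes if node.target == target]

def set_metadata_sketch(
    gm: GraphModule, target: Any, metadata: List[Dict[str, Any]]
) -> None:
    """Re-attach the saved meta dicts to the rewritten nodes, in graph order."""
    nodes = [node for node in gm.graph.nodes if node.target == target]
    for node, meta in zip(nodes, metadata):
        node.meta.update(meta)
```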
