@@ -20,18 +20,18 @@ class PerSubgraphData:
20
20
Args:
21
21
subgraph_name (str): Name of the subgraph in the GraphModule
22
22
subgraph_op_count (int): Number of operations in the subgraph
23
- subgraph_input_shapes (Any): Shapes of input Tensors of the subgraph
24
- subgraph_input_dtypes (Any): Input data types of the subgraph
25
- subgraph_output_shapes (Any): Shapes of output Tensors of the subgraph
26
- subgraph_output_dtypes (Any): Output data types of the subgraph
23
+ input_shapes (Any): Shapes of input Tensors of the subgraph
24
+ input_dtypes (Any): Input data types of the subgraph
25
+ output_shapes (Any): Shapes of output Tensors of the subgraph
26
+ output_dtypes (Any): Output data types of the subgraph
27
27
"""
28
28
29
29
subgraph_name : str = ""
30
30
subgraph_op_count : int = 0
31
- subgraph_input_shapes : Any = field (default_factory = list )
32
- subgraph_input_dtypes : Any = field (default_factory = list )
33
- subgraph_output_shapes : Any = field (default_factory = list )
34
- subgraph_output_dtypes : Any = field (default_factory = list )
31
+ input_shapes : Any = field (default_factory = list )
32
+ input_dtypes : Any = field (default_factory = list )
33
+ output_shapes : Any = field (default_factory = list )
34
+ output_dtypes : Any = field (default_factory = list )
35
35
36
36
37
37
@dataclass
@@ -41,10 +41,10 @@ class DryRunTracker:
41
41
Args:
42
42
total_ops_in_graph (int): Total number of operators in graph
43
43
supported_ops_in_graph (int): Number of supported operators in graph
44
- graph_input_shapes (Any): Shapes of input Tensors of the graph
45
- graph_input_dtypes (Any): Input data types of the graph
46
- graph_output_shapes (Any): Shapes of output Tensors of the graph
47
- graph_output_dtypes (Any): Output data types of the graph
44
+ input_shapes (Any): Shapes of input Tensors of the graph
45
+ input_dtypes (Any): Input data types of the graph
46
+ output_shapes (Any): Shapes of output Tensors of the graph
47
+ output_dtypes (Any): Output data types of the graph
48
48
per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
49
49
tensorrt_graph_count (int): Number of TensorRT engines to be generated
50
50
compilation_settings (CompilationSettings): User Compilation Settings
@@ -54,10 +54,10 @@ class DryRunTracker:
54
54
55
55
total_ops_in_graph : int = 0
56
56
supported_ops_in_graph : int = 0
57
- graph_input_shapes : Any = field (default_factory = list )
58
- graph_input_dtypes : Any = field (default_factory = list )
59
- graph_output_shapes : Any = field (default_factory = list )
60
- graph_output_dtypes : Any = field (default_factory = list )
57
+ input_shapes : Any = field (default_factory = list )
58
+ input_dtypes : Any = field (default_factory = list )
59
+ output_shapes : Any = field (default_factory = list )
60
+ output_dtypes : Any = field (default_factory = list )
61
61
per_subgraph_data : List [PerSubgraphData ] = field (default_factory = list )
62
62
tensorrt_graph_count : int = 0
63
63
compilation_settings : CompilationSettings = field (
@@ -111,7 +111,7 @@ def dryrun_stats_display(
111
111
formatted_stats += " " * 2 + "Graph Structure:\n \n "
112
112
formatted_stats += (
113
113
" " * 3
114
- + f"Inputs: { input_formatter (dryrun_tracker .graph_input_shapes , dryrun_tracker .graph_input_dtypes )} \n "
114
+ + f"Inputs: { input_formatter (dryrun_tracker .input_shapes , dryrun_tracker .input_dtypes )} \n "
115
115
)
116
116
117
117
for i , trt_subgraph_data in enumerate (dryrun_tracker .per_subgraph_data ):
@@ -122,21 +122,21 @@ def dryrun_stats_display(
122
122
)
123
123
formatted_stats += (
124
124
" " * 5
125
- + f"Engine Inputs: { input_formatter (trt_subgraph_data .subgraph_input_shapes , trt_subgraph_data .subgraph_input_dtypes )} \n "
125
+ + f"Engine Inputs: { input_formatter (trt_subgraph_data .input_shapes , trt_subgraph_data .input_dtypes )} \n "
126
126
)
127
127
formatted_stats += (
128
128
" " * 5
129
129
+ f"Number of Operators in Engine: { trt_subgraph_data .subgraph_op_count } \n "
130
130
)
131
131
formatted_stats += (
132
132
" " * 5
133
- + f"Engine Outputs: { input_formatter (trt_subgraph_data .subgraph_output_shapes , trt_subgraph_data .subgraph_output_dtypes )} \n "
133
+ + f"Engine Outputs: { input_formatter (trt_subgraph_data .output_shapes , trt_subgraph_data .output_dtypes )} \n "
134
134
)
135
135
136
136
formatted_stats += " " * 4 + "...\n "
137
137
formatted_stats += (
138
138
" " * 3
139
- + f"Outputs: { input_formatter (dryrun_tracker .graph_output_shapes , dryrun_tracker .graph_output_dtypes )} \n "
139
+ + f"Outputs: { input_formatter (dryrun_tracker .output_shapes , dryrun_tracker .output_dtypes )} \n "
140
140
)
141
141
142
142
# Print aggregate statistics about the graph structure, including recommended "min_block_size" options
@@ -225,11 +225,20 @@ def input_formatter(shapes: Any, dtypes: Any) -> str:
225
225
226
226
def input_formatter_helper (shapes : Any , dtypes : Any ) -> str :
227
227
"""Helper for input formatter"""
228
- # Base case - single shape, single dtype
229
- if isinstance (shapes , tuple ) and all (isinstance (elt , int ) for elt in shapes ):
230
- return f"Tensor: { shapes } @{ str (dtypes )[6 :]} , "
231
-
232
- # Base case - dynamic shape, single dtype
228
+ # Base case 1 - single static/dynamic shape, single dtype
229
+ if isinstance (shapes , tuple ) and all (
230
+ isinstance (elt , (int , tuple )) for elt in shapes
231
+ ):
232
+ input_shape_string = "Tensor: ("
233
+ for elt in shapes :
234
+ if isinstance (elt , tuple ):
235
+ input_shape_string += f"(min={ elt [0 ]} , max={ elt [1 ]} ), "
236
+ else :
237
+ input_shape_string += f"{ elt } , "
238
+ input_shape_string = input_shape_string [:- 2 ] + ")" + f"@{ str (dtypes )[6 :]} , "
239
+ return input_shape_string
240
+
241
+ # Base case 2 - dynamic shape, single dtype
233
242
elif (
234
243
isinstance (shapes , dict )
235
244
and len (shapes ) == 3
0 commit comments