@@ -56,6 +56,7 @@ def convert_pt2(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     quantizer: CadenceQuantizer,
+    dump_graphs: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Prepare and convert a model using the given quantizer.
@@ -86,6 +87,10 @@ def convert_pt2(
         .module()
     )
 
+    if dump_graphs:
+        logging.info("Graph before quantization:")
+        logging.info(model_gm.graph.print_tabular())
+
     # Prepare
     prepared_model = prepare_pt2e(model_gm, quantizer)
 
@@ -95,6 +100,10 @@ def convert_pt2(
     # Convert
     converted_model = convert_pt2e(prepared_model)
 
+    if dump_graphs:
+        logging.info("Graph after quantization (before fusion):")
+        logging.info(model_gm.graph.print_tabular())
+
     return converted_model
 
 
@@ -127,6 +136,7 @@ def quantize_pt2(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     quantizer: Optional[CadenceQuantizer] = None,
+    dump_graphs: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Prepare, convert and fuse the model using the given quantizer.
@@ -140,19 +150,22 @@ def quantize_pt2(
         quantizer = CadenceDefaultQuantizer()
 
     # Get converted graph module
-    converted_gm = convert_pt2(model, inputs, quantizer)
+    converted_gm = convert_pt2(model, inputs, quantizer, dump_graphs)
 
     # Get fused model
     fused_gm = fuse_pt2(converted_gm, quantizer)
 
+    if dump_graphs:
+        logging.info("Graph after quantization and fusion:")
+        logging.info(fused_gm.graph.print_tabular())
+
     return fused_gm
 
 
 # Export the model and lower it to an ExportedProgram (in aten IR)
 def export_program(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    dump_graphs: bool = False,
 ) -> ExportedProgram:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
@@ -162,10 +175,6 @@ def export_program(
     # Export the model and return it.
     expo_program = export(model, inputs, strict=True)
 
-    if dump_graphs:
-        logging.info("Exported graph:")
-        expo_program.graph_module.graph.print_tabular()
-
     return expo_program
 
 
@@ -179,7 +188,7 @@ def export_to_edge(
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
     # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs, dump_graphs=dump_graphs)
+    expo_program = export_program(model, inputs)
 
     # Call to_edge to convert the graph to edge IR.
     # Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
@@ -200,8 +209,10 @@ def export_to_edge(
     )
 
     if dump_graphs:
-        logging.info("Edge graph:")
-        edge_prog_manager.exported_program().graph_module.graph.print_tabular()
+        logging.info("Graph after Edge lowering:")
+        logging.info(
+            edge_prog_manager.exported_program().graph_module.graph.print_tabular()
+        )
 
     return edge_prog_manager
 