
Commit 2523948

Ethan Ng authored and meta-codesync[bot] committed
Update fuse_pt2 to take and return an ExportedProgram (#15474)
Summary: Pull Request resolved: #15474. ModAI canonically expects the quant fusion step (pre_edge_transforms) to take and return an ExportedProgram.

Differential Revision: D85808310

Reviewed By: mcremon-meta
1 parent 0b248b7 commit 2523948

File tree

5 files changed: +94 -22 lines changed

backends/cadence/aot/compiler.py
backends/cadence/aot/export_example.py
backends/cadence/aot/quantizer/fusion_pass.py
backends/cadence/aot/quantizer/utils.py
util/activation_memory_profiler.py

backends/cadence/aot/compiler.py

Lines changed: 13 additions & 14 deletions
@@ -38,7 +38,7 @@
 )
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.exir.program._program import to_edge
+from executorch.exir.program._program import _transform, to_edge
 
 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
@@ -145,22 +145,22 @@ def convert_pt2(
 # fused model, to be able to get reference numerics.
 # If this does not apply, please use quantize_pt2 instead.
 def fuse_pt2(
-    converted_graph_module: torch.fx.GraphModule,
+    converted_program: ExportedProgram,
     quantizer: CadenceQuantizer,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
-    Fuse a converted graph module using the given quantizer.
+    Fuse a converted exported program using the given quantizer.
     The quantizer must be the same as the one used to convert the model.
     If you do not expect that behavior, please use quantize_pt2 instead,
     which will instantiate a default quantizer for you if needed.
-    Returns a GraphModule with the fused model.
+    Returns an ExportedProgram with the fused model.
     """
     # Get patterns and apply fusion of dq -> op -> q to qop
     # pyre-ignore[16]: no attribute
     patterns = [q.pattern for q in quantizer.quantizers]
-    QuantFusion(patterns)(converted_graph_module)
+    fused_program = _transform(converted_program, QuantFusion(patterns))
 
-    return converted_graph_module
+    return fused_program
 
 
 # Note: quantizer is not optional here to force the user to supply a quantizer
@@ -210,7 +210,7 @@ def quantize_pt2(
     If calibration data is provided, it will be used to calibrate the model. If
     not, the inputs will be used for calibration instead, which is useful for
     unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the quantized model.
+    Returns an ExportedProgram with the quantized model.
     Note: this function should not be called directly in general. Please use
     quantize_and_export_to_executorch for most needs.
     """
@@ -227,16 +227,15 @@ def quantize_pt2(
         dump_graphs=dump_graphs,
    )
 
-    # Get fused model
-    fused_gm = fuse_pt2(converted_gm, quantizer)
+    # Apply quant fusion to the exported program
+    program = torch.export.export(converted_gm, inputs, strict=True)
+    fused_program = fuse_pt2(program, quantizer)
 
     if dump_graphs:
         logging.info("Graph after quantization and fusion:")
-        logging.info(fused_gm.graph.print_tabular())
+        logging.info(fused_program.graph_module.graph.print_tabular())
 
-    program = torch.export.export(fused_gm, inputs, strict=True)
-
-    return program
+    return fused_program
 
 
 TO_EDGE_OP_EXCEPTION_LIST: list[torch._ops.OpOverload] = [
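With this change, callers export the converted module first and then fuse the resulting ExportedProgram. A minimal sketch of the new contract, assuming `converted_gm`, `inputs`, and `quantizer` are the same objects used in the convert step (names mirror quantize_pt2 above; this is illustrative, not part of the diff):

# Illustrative sketch only: fuse_pt2 now takes and returns an ExportedProgram.
import torch

from executorch.backends.cadence.aot.compiler import fuse_pt2

# `converted_gm`, `inputs`, and `quantizer` are assumed to come from the
# prepare/convert step, as in quantize_pt2 above.
ep = torch.export.export(converted_gm, inputs, strict=True)
fused_ep = fuse_pt2(ep, quantizer)            # ExportedProgram in, ExportedProgram out
fused_ep.graph_module.graph.print_tabular()   # same inspection used under dump_graphs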

backends/cadence/aot/export_example.py

Lines changed: 3 additions & 4 deletions
@@ -63,11 +63,10 @@ def export_model(
     # Get reference outputs from converted model
     ref_outputs = converted_model(*example_inputs)
 
-    # Quantize the model (note: quantizer needs to be the same as
-    # the one used in prepare_and_convert_pt2)
-    quantized_model = fuse_pt2(converted_model, quantizer)
+    ep = torch.export.export(converted_model, example_inputs, strict=True)
 
-    ep = torch.export.export(quantized_model, example_inputs, strict=True)
+    # Fuse the quantized patterns on the exported program (note: quantizer needs to be the same as the one used in prepare_and_convert_pt2)
+    ep = fuse_pt2(ep, quantizer)
 
     # Get edge program after Cadence specific passes
     exec_prog: ExecutorchProgramManager = _lower_ep_to_cadence_gen_etrecord(
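Since `ref_outputs` is still captured from the converted model before fusion, the fused ExportedProgram can be sanity-checked against it. A hedged sketch, not part of the commit (tolerances are illustrative):

# ExportedProgram.module() returns a callable module, so the fused program can
# be run on the same example inputs and compared with the reference outputs.
fused_outputs = ep.module()(*example_inputs)
torch.testing.assert_close(fused_outputs, ref_outputs, rtol=1e-2, atol=1e-2)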

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 37 additions & 0 deletions
@@ -33,6 +33,7 @@
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     check_out_zero_point_is_min_range,
+    copy_node_metadata,
     create_zero_bias_int32,
     find_sequential_partitions_aten,
     get_conv_args,
@@ -395,10 +396,13 @@ def get_args_and_kwargs_mixed_w8a32_conv(
         torch.ops.aten.permute.default,
         (other_inputs[0], [0, 2, 1]),  # NCL -> NLC
     )
+    copy_node_metadata(transposed_inputs, other_inputs[0])
+
     transposed_weights = graph_module.graph.call_function(
         torch.ops.aten.permute.default,
         (weights_inputs[0], [2, 0, 1]),  # NCL -> LNC
     )
+    copy_node_metadata(transposed_weights, weights_inputs[0])
 
     args = (
         transposed_inputs,
@@ -582,6 +586,26 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     torch.ops.aten.transpose.int,
                     (weights_inputs[0], 0, 1),
                 )
+                if "val" in weights_inputs[0].meta:
+                    original_val = weights_inputs[0].meta["val"]
+                    fake_mode = original_val.fake_mode
+                    if fake_mode is not None:
+                        with fake_mode:
+                            transposed_val = torch.ops.aten.transpose.int(
+                                original_val, 0, 1
+                            )
+                        transposed_weights.meta["val"] = transposed_val
+                    else:
+                        transposed_shape = list(original_val.shape)
+                        transposed_shape[0], transposed_shape[1] = (
+                            transposed_shape[1],
+                            transposed_shape[0],
+                        )
+                        transposed_weights.meta["val"] = torch.zeros(
+                            transposed_shape, dtype=original_val.dtype
+                        )
+                copy_node_metadata(transposed_weights, weights_inputs[0])
+
                 # Call linear with transposed weight
                 args, kwargs = get_args_and_kwargs_linear(
                     graph_module,
@@ -654,6 +678,19 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
 
         legalize_graph(graph_module)
         graph_module.graph.eliminate_dead_code()
+        nodes_list = list(graph_module.graph.nodes)
+
+        if len(nodes_list) > 0 and nodes_list[-1].op != "output":
+            output_nodes = [n for n in nodes_list if n.op == "output"]
+            output_arg = output_nodes[0].args[0]
+            original_meta = output_nodes[0].meta.copy()
+
+            for out_node in output_nodes:
+                graph_module.graph.erase_node(out_node)
+
+            new_output_node = graph_module.graph.output(output_arg)
+            new_output_node.meta.update(original_meta)
+
         graph_module.recompile()
         return PassResult(graph_module, True)
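The new hunks populate `meta["val"]` on nodes the pass creates, computing the fake value under the original FakeTensorMode when one is available so the resulting ExportedProgram stays consistent. A standalone sketch of that idea (toy shapes, not taken from the pass):

# Minimal illustration of producing a FakeTensor "val" for a transposed weight,
# mirroring the fake_mode branch above. The shape (4, 8) is made up for the example.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
with fake_mode:
    weight_val = torch.empty(4, 8)  # stands in for weights_inputs[0].meta["val"]
    transposed_val = torch.ops.aten.transpose.int(weight_val, 0, 1)

assert list(transposed_val.shape) == [8, 4]  # the shape recorded in the new node's meta["val"]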

backends/cadence/aot/quantizer/utils.py

Lines changed: 38 additions & 2 deletions
@@ -24,6 +24,12 @@
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
 
 
+def copy_node_metadata(dest_node: fx.Node, src_node: fx.Node) -> None:
+    for key in ["nn_module_stack", "stack_trace", "source_fn_stack"]:
+        if key in src_node.meta and src_node.meta[key]:
+            dest_node.meta[key] = src_node.meta[key]
+
+
 def quantize_tensor_multiplier(
     requantize_scale_tensor: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -114,15 +120,45 @@ def create_zero_bias_int32(
     """
     Creates a zero bias tensor with the shape of weight[0]
     """
-    attr_node = getattr(graph_module, weight_node.target)
+    try:
+        attr_node = getattr(graph_module, weight_node.target)
+    except AttributeError:
+        if "val" in weight_node.meta:
+            attr_node = weight_node.meta["val"]
+        else:
+            param_dict = dict(graph_module.named_parameters())
+            if weight_node.target in param_dict:
+                attr_node = param_dict[weight_node.target]
+            else:
+                buffer_dict = dict(graph_module.named_buffers())
+                if weight_node.target in buffer_dict:
+                    attr_node = buffer_dict[weight_node.target]
+                else:
+                    raise AttributeError(
+                        f"Could not find weight tensor for node {weight_node.target}. "
+                        f"Node metadata keys: {list(weight_node.meta.keys())}"
+                    )
+
     weight_shape = list(attr_node.shape)
     bias_shape = weight_shape[0]
-    return graph_module.graph.call_function(
+    new_node = graph_module.graph.call_function(
         torch.ops.aten.full.default,
         ([bias_shape], 0.0),
         {"dtype": torch.int32},
     )
 
+    if "val" in weight_node.meta:
+        fake_mode = weight_node.meta["val"].fake_mode
+        if fake_mode is not None:
+            with fake_mode:
+                fake_bias = torch.zeros([bias_shape], dtype=torch.int32)
+            new_node.meta["val"] = fake_bias
+        else:
+            new_node.meta["val"] = torch.zeros([bias_shape], dtype=torch.int32)
+    copy_node_metadata(new_node, weight_node)
+
+    return new_node
+
 
 def get_bias_qparams(
     obs_or_fqs: List[ObserverOrFakeQuantize],
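The new fallback in create_zero_bias_int32 matters because, after export, weights are typically reachable only through dotted targets in named_parameters()/named_buffers(), where a plain getattr raises AttributeError. A small illustration with a hypothetical module (not from the diff):

# Hypothetical module used only to illustrate the lookup fallback.
import torch

class Tiny(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

m = Tiny()
target = "linear.weight"                       # dotted target, not a direct attribute
try:
    weight = getattr(m, target)                # raises AttributeError
except AttributeError:
    weight = dict(m.named_parameters())[target]

bias_shape = list(weight.shape)[0]             # 4, the shape used for the zero bias node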

util/activation_memory_profiler.py

Lines changed: 3 additions & 2 deletions
@@ -41,9 +41,10 @@ def _get_module_hierarchy(node: torch.fx.Node) -> str:
     Get the module hierarchy of the given node.
     """
     module_stack = node.meta.get("nn_module_stack")
-    if module_stack is not None:
+    if module_stack is not None and module_stack:
         module_values_list = list(module_stack.values())
-        return module_values_list[-1][0]
+        if module_values_list:
+            return module_values_list[-1][0]
     return ""
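Nodes inserted by passes can carry an empty `nn_module_stack` (or none at all), which previously made the `[-1]` lookup fail. A tiny sketch of the guarded lookup (the meta dict is illustrative):

# Illustrative only: an empty nn_module_stack now falls through to "".
meta = {"nn_module_stack": {}}
module_stack = meta.get("nn_module_stack")
hierarchy = ""
if module_stack is not None and module_stack:
    module_values_list = list(module_stack.values())
    if module_values_list:
        hierarchy = module_values_list[-1][0]
print(repr(hierarchy))  # ''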
