Commit d9a8f2d

Update fuse_pt2 to take and return an ExportedProgram
Differential Revision: D85808310
Pull Request resolved: #15474
1 parent 1744514 commit d9a8f2d

File tree

5 files changed: +100 -22 lines

  backends/cadence/aot/compiler.py
  backends/cadence/aot/export_example.py
  backends/cadence/aot/quantizer/fusion_pass.py
  backends/cadence/aot/quantizer/utils.py
  util/activation_memory_profiler.py

backends/cadence/aot/compiler.py

Lines changed: 13 additions & 14 deletions

@@ -38,7 +38,7 @@
 )
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.exir.program._program import to_edge
+from executorch.exir.program._program import _transform, to_edge

 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e

@@ -145,22 +145,22 @@ def convert_pt2(
 # fused model, to be able to get reference numerics.
 # If this does not apply, please use quantize_pt2 instead.
 def fuse_pt2(
-    converted_graph_module: torch.fx.GraphModule,
+    converted_program: ExportedProgram,
     quantizer: CadenceQuantizer,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
-    Fuse a converted graph module using the given quantizer.
+    Fuse a converted exported program using the given quantizer.
     The quantizer must be the same as the one used to convert the model.
     If you do not expect that behavior, please use quantize_pt2 instead,
     which will instantiate a default quantizer for you if needed.
-    Returns a GraphModule with the fused model.
+    Returns an ExportedProgram with the fused model.
     """
     # Get patterns and apply fusion of dq -> op -> q to qop
     # pyre-ignore[16]: no attribute
     patterns = [q.pattern for q in quantizer.quantizers]
-    QuantFusion(patterns)(converted_graph_module)
+    fused_program = _transform(converted_program, QuantFusion(patterns))

-    return converted_graph_module
+    return fused_program


 # Note: quantizer is not optional here to force the user to supply a quantizer

@@ -210,7 +210,7 @@ def quantize_pt2(
     If calibration data is provided, it will be used to calibrate the model. If
     not, the inputs will be used for calibration instead, which is useful for
     unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the quantized model.
+    Returns an ExportedProgram with the quantized model.
     Note: this function should not be called directly in general. Please use
     quantize_and_export_to_executorch for most needs.
     """

@@ -227,16 +227,15 @@ def quantize_pt2(
         dump_graphs=dump_graphs,
     )

-    # Get fused model
-    fused_gm = fuse_pt2(converted_gm, quantizer)
+    # Apply quant fusion to the exported program
+    program = torch.export.export(converted_gm, inputs, strict=True)
+    fused_program = fuse_pt2(program, quantizer)

     if dump_graphs:
         logging.info("Graph after quantization and fusion:")
-        logging.info(fused_gm.graph.print_tabular())
+        logging.info(fused_program.graph_module.graph.print_tabular())

-    program = torch.export.export(fused_gm, inputs, strict=True)
-
-    return program
+    return fused_program


 TO_EDGE_OP_EXCEPTION_LIST: list[torch._ops.OpOverload] = [
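
For downstream callers, the practical effect of this change is that quant fusion now happens on an exported program instead of mutating the converted GraphModule in place. A minimal sketch of the new call pattern, assuming `converted_gm`, `example_inputs`, and `quantizer` come from the existing prepare/convert step (the names are illustrative, not part of this diff):

# fuse_pt2 now expects an ExportedProgram, so export the converted module first.
ep = torch.export.export(converted_gm, example_inputs, strict=True)

# The same CadenceQuantizer used for conversion must be passed in; the result
# is again an ExportedProgram, produced via exir's _transform + QuantFusion.
fused_ep = fuse_pt2(ep, quantizer)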

backends/cadence/aot/export_example.py

Lines changed: 3 additions & 4 deletions

@@ -63,11 +63,10 @@ def export_model(
     # Get reference outputs from converted model
     ref_outputs = converted_model(*example_inputs)

-    # Quantize the model (note: quantizer needs to be the same as
-    # the one used in prepare_and_convert_pt2)
-    quantized_model = fuse_pt2(converted_model, quantizer)
+    ep = torch.export.export(converted_model, example_inputs, strict=True)

-    ep = torch.export.export(quantized_model, example_inputs, strict=True)
+    # Fuse the quantized patterns on the exported program (note: quantizer needs to be the same as the one used in prepare_and_convert_pt2)
+    ep = fuse_pt2(ep, quantizer)

     # Get edge program after Cadence specific passes
     exec_prog: ExecutorchProgramManager = _lower_ep_to_cadence_gen_etrecord(

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 43 additions & 0 deletions

@@ -33,6 +33,7 @@
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     check_out_zero_point_is_min_range,
+    copy_node_metadata,
     create_zero_bias_int32,
     find_sequential_partitions_aten,
     get_conv_args,

@@ -159,6 +160,8 @@ def get_args_and_kwargs_layer_norm(
         ),
         {"dtype": torch.float32},
     )
+    if len(inputs_inputs) > 0:
+        copy_node_metadata(weight, inputs_inputs[0])

     bias = other_inputs[2] if len(other_inputs) > 2 else None

@@ -171,6 +174,8 @@ def get_args_and_kwargs_layer_norm(
         ),
         {"dtype": torch.float32},
     )
+    if len(inputs_inputs) > 0:
+        copy_node_metadata(bias, inputs_inputs[0])

     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs + [scale, zero_point])

@@ -346,6 +351,8 @@ def get_args_and_kwargs_softmax(
         ),
         {"dtype": torch.int32},
     )
+    if len(inputs_inputs) > 0:
+        copy_node_metadata(mask_tensor, inputs_inputs[0])
     # Make the scale and zero_point tensors
     in_scale = dequants_inputs[0].args[1]
     in_zero_point = dequants_inputs[0].args[2]

@@ -395,10 +402,13 @@ def get_args_and_kwargs_mixed_w8a32_conv(
         torch.ops.aten.permute.default,
         (other_inputs[0], [0, 2, 1]),  # NCL -> NLC
     )
+    copy_node_metadata(transposed_inputs, other_inputs[0])
+
     transposed_weights = graph_module.graph.call_function(
         torch.ops.aten.permute.default,
         (weights_inputs[0], [2, 0, 1]),  # NCL -> LNC
     )
+    copy_node_metadata(transposed_weights, weights_inputs[0])

     args = (
         transposed_inputs,

@@ -582,6 +592,26 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     torch.ops.aten.transpose.int,
                     (weights_inputs[0], 0, 1),
                 )
+                if "val" in weights_inputs[0].meta:
+                    original_val = weights_inputs[0].meta["val"]
+                    fake_mode = original_val.fake_mode
+                    if fake_mode is not None:
+                        with fake_mode:
+                            transposed_val = torch.ops.aten.transpose.int(
+                                original_val, 0, 1
+                            )
+                            transposed_weights.meta["val"] = transposed_val
+                    else:
+                        transposed_shape = list(original_val.shape)
+                        transposed_shape[0], transposed_shape[1] = (
+                            transposed_shape[1],
+                            transposed_shape[0],
+                        )
+                        transposed_weights.meta["val"] = torch.zeros(
+                            transposed_shape, dtype=original_val.dtype
+                        )
+                copy_node_metadata(transposed_weights, weights_inputs[0])
+
                 # Call linear with transposed weight
                 args, kwargs = get_args_and_kwargs_linear(
                     graph_module,

@@ -654,6 +684,19 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901

         legalize_graph(graph_module)
         graph_module.graph.eliminate_dead_code()
+        nodes_list = list(graph_module.graph.nodes)
+
+        if len(nodes_list) > 0 and nodes_list[-1].op != "output":
+            output_nodes = [n for n in nodes_list if n.op == "output"]
+            output_arg = output_nodes[0].args[0]
+            original_meta = output_nodes[0].meta.copy()
+
+            for out_node in output_nodes:
+                graph_module.graph.erase_node(out_node)
+
+            new_output_node = graph_module.graph.output(output_arg)
+            new_output_node.meta.update(original_meta)
+
         graph_module.recompile()
         return PassResult(graph_module, True)
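
The metadata handling added above is needed because QuantFusion now runs on an ExportedProgram via `_transform`, where newly created nodes are expected to carry a FakeTensor `val` (for shape propagation and validation) as well as module-stack metadata (consumed, for instance, by the activation memory profiler further down). A rough sketch of that pattern, assuming the source node already has a `val`; the helper name here is hypothetical:

import torch
from torch import fx

from executorch.backends.cadence.aot.quantizer.utils import copy_node_metadata


def insert_transpose_with_meta(graph: fx.Graph, src: fx.Node) -> fx.Node:
    # Create the new node, as the pass does for transposed weights.
    new_node = graph.call_function(torch.ops.aten.transpose.int, (src, 0, 1))
    if "val" in src.meta:
        original_val = src.meta["val"]
        fake_mode = getattr(original_val, "fake_mode", None)
        if fake_mode is not None:
            # Re-run the op on the FakeTensor so the new node's "val" carries
            # the transposed shape without touching real data.
            with fake_mode:
                new_node.meta["val"] = torch.ops.aten.transpose.int(original_val, 0, 1)
    # Carry over nn_module_stack / stack_trace / source_fn_stack (see utils.py below).
    copy_node_metadata(new_node, src)
    return new_node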

backends/cadence/aot/quantizer/utils.py

Lines changed: 38 additions & 2 deletions

@@ -24,6 +24,12 @@
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


+def copy_node_metadata(dest_node: fx.Node, src_node: fx.Node) -> None:
+    for key in ["nn_module_stack", "stack_trace", "source_fn_stack"]:
+        if key in src_node.meta and src_node.meta[key]:
+            dest_node.meta[key] = src_node.meta[key]
+
+
 def quantize_tensor_multiplier(
     requantize_scale_tensor: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -114,15 +120,45 @@ def create_zero_bias_int32(
     """
     Creates a zero bias tensor with the shape of weight[0]
     """
-    attr_node = getattr(graph_module, weight_node.target)
+    try:
+        attr_node = getattr(graph_module, weight_node.target)
+    except AttributeError:
+        if "val" in weight_node.meta:
+            attr_node = weight_node.meta["val"]
+        else:
+            param_dict = dict(graph_module.named_parameters())
+            if weight_node.target in param_dict:
+                attr_node = param_dict[weight_node.target]
+            else:
+                buffer_dict = dict(graph_module.named_buffers())
+                if weight_node.target in buffer_dict:
+                    attr_node = buffer_dict[weight_node.target]
+                else:
+                    raise AttributeError(
+                        f"Could not find weight tensor for node {weight_node.target}. "
+                        f"Node metadata keys: {list(weight_node.meta.keys())}"
+                    )
+
     weight_shape = list(attr_node.shape)
     bias_shape = weight_shape[0]
-    return graph_module.graph.call_function(
+    new_node = graph_module.graph.call_function(
         torch.ops.aten.full.default,
         ([bias_shape], 0.0),
         {"dtype": torch.int32},
     )

+    if "val" in weight_node.meta:
+        fake_mode = weight_node.meta["val"].fake_mode
+        if fake_mode is not None:
+            with fake_mode:
+                fake_bias = torch.zeros([bias_shape], dtype=torch.int32)
+            new_node.meta["val"] = fake_bias
+        else:
+            new_node.meta["val"] = torch.zeros([bias_shape], dtype=torch.int32)
+    copy_node_metadata(new_node, weight_node)
+
+    return new_node
+

 def get_bias_qparams(
     obs_or_fqs: List[ObserverOrFakeQuantize],
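
As a quick illustration of `copy_node_metadata` on a toy FX graph (the nodes and metadata values below are made up): keys that are absent or empty on the source node are skipped, which is why consumers such as the profiler change below still have to tolerate missing metadata.

import torch
from torch import fx

from executorch.backends.cadence.aot.quantizer.utils import copy_node_metadata

graph = fx.Graph()
x = graph.placeholder("x")
src = graph.call_function(torch.relu, (x,))
dest = graph.call_function(torch.sigmoid, (src,))

src.meta["nn_module_stack"] = {"block": ("block", torch.nn.ReLU)}  # non-empty: copied
src.meta["stack_trace"] = ""  # empty value: intentionally skipped

copy_node_metadata(dest, src)
assert "nn_module_stack" in dest.meta
assert "stack_trace" not in dest.meta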

util/activation_memory_profiler.py

Lines changed: 3 additions & 2 deletions

@@ -41,9 +41,10 @@ def _get_module_hierarchy(node: torch.fx.Node) -> str:
     Get the module hierarchy of the given node.
     """
     module_stack = node.meta.get("nn_module_stack")
-    if module_stack is not None:
+    if module_stack is not None and module_stack:
         module_values_list = list(module_stack.values())
-        return module_values_list[-1][0]
+        if module_values_list:
+            return module_values_list[-1][0]
     return ""
