Skip to content

Commit 771fa2c

Browse files
authored
Merge branch 'main' into export-D82602663
2 parents 2f3264c + 44972ad commit 771fa2c

File tree

5 files changed

+144
-100
lines changed

5 files changed

+144
-100
lines changed

.github/workflows/add-unanswered-to-project.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
2121
// List of authors to exclude
2222
const excludedAuthors = new Set([
23-
"nil-is-all", "cbilgin", "KimishPatel", "psiddh", "digantdesai", "SS-JIA", "ahmtox", "mcr229", "shoumikhin",
23+
"nil-is-all", "cbilgin", "kimishpatel", "psiddh", "digantdesai", "SS-JIA", "ahmtox", "mcr229", "shoumikhin",
2424
"manuelcandales", "metascroy", "cccclai", "rohansjoshi", "kirklandsign", "abhinaykukkadapu", "JacobSzwejbka",
2525
"Conarnar", "lucylq", "larryliu0820", "BujSet", "Gasoonjia", "Juntian777", "guangy10", "jackzhxng",
2626
"GregoryComer", "leafs1", "swolchok", "mergennachin", "tarun292", "byjlw", "jathu", "Jack-Khuu", "georgehong",

.github/workflows/docker-builds.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
strategy:
3232
fail-fast: false
3333
matrix:
34-
runner: [linux.2xlarge]
34+
runner: [linux.4xlarge]
3535
docker-image-name: [
3636
executorch-ubuntu-22.04-gcc9,
3737
executorch-ubuntu-22.04-clang12,

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
471471
pattern.partition_types(),
472472
)
473473
for fused_partition in fused_partitions:
474-
anchors = pattern.get_anchors(graph_module, fused_partition)
474+
anchors, op_node = pattern.get_anchors(graph_module, fused_partition)
475475
if not anchors or anchors.empty:
476476
continue
477477
if any(self.is_fused(p.nodes) for p in fused_partition):
@@ -512,13 +512,10 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
512512
bias_inputs = [node.args[0] for node in dequants_biases]
513513
other_inputs = [node.args[idx] for node, idx in anchors.others]
514514

515-
# The node is the first index of the list and first of the tuple
516-
anchor_output_node = anchors.output[0][0]
515+
assert op_node is not None, "op_node is None"
516+
quant_node = list(op_node.users.keys())[0]
517517

518-
assert len(anchor_output_node.users) == 1
519-
quant_node = list(anchor_output_node.users.keys())[0]
520-
521-
with graph_module.graph.inserting_after(anchor_output_node):
518+
with graph_module.graph.inserting_after(op_node):
522519
args = tuple(
523520
inputs_inputs + weights_inputs + other_inputs + bias_inputs
524521
)
@@ -532,7 +529,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
532529
)
533530
elif isinstance(pattern, CatPattern):
534531
args, kwargs = get_args_and_kwargs_cat(
535-
inputs_inputs, other_inputs, anchor_output_node
532+
inputs_inputs, other_inputs, op_node
536533
)
537534
elif isinstance(pattern, ConvReluPatterns):
538535
# For ConvReLU, we are fusing Conv+ReLU
@@ -563,7 +560,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
563560
dequants_weights,
564561
bias_inputs,
565562
quant_node,
566-
anchor_output_node,
563+
op_node,
567564
)
568565
elif isinstance(pattern, LinearPattern):
569566
args, kwargs = get_args_and_kwargs_linear(
@@ -618,20 +615,28 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
618615
inputs_inputs,
619616
dequants_inputs,
620617
quant_node,
621-
anchor_output_node,
618+
op_node,
622619
)
620+
623621
fused = graph_module.graph.call_function(
624622
pattern.replacement_op(),
625623
args,
626624
kwargs,
627625
)
628-
fused.meta = quant_node.meta
629-
quant_node.replace_all_uses_with(fused)
626+
627+
if len(anchors.output) > 0:
628+
fused.meta = quant_node.meta
629+
quant_node.replace_all_uses_with(fused)
630+
else:
631+
fused.meta = op_node.meta
632+
op_node.replace_all_uses_with(fused)
633+
if op_node.op == "output":
634+
_ = graph_module.graph.output((fused,))
630635

631636
legalize_graph(graph_module)
632637
graph_module.graph.eliminate_dead_code()
633-
# pyre-fixme[7]: Incompatible return type
634638
graph_module.recompile()
639+
return PassResult(graph_module, True)
635640

636641
@classmethod
637642
# pyre-ignore[2]: Parameter `nodes` has no type specified

0 commit comments

Comments
 (0)