Commit 5fab6a5

Martin Lindström authored and committed

Arm backend: Correct type annotations in aot_arm_compiler

- Correct/add type annotation in aot_arm_compiler.py
- Remove one redundant variable assignment (dead code)

Signed-off-by: Martin Lindström <[email protected]>
Change-Id: I7a129c53c25b991e82b12904bd4d714d0937403a

1 parent 4622edb · commit 5fab6a5

File tree

1 file changed: +20 −6 lines changed


examples/arm/aot_arm_compiler.py

Lines changed: 20 additions & 6 deletions
@@ -62,6 +62,8 @@
 
 from executorch.extension.export_util.utils import save_pte_program
 from tabulate import tabulate
+from torch.export import ExportedProgram
+from torch.fx import GraphModule
 from torch.utils.data import DataLoader
 
 # Quantize model if required using the standard export quantizaion flow.
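The two imports added above back the new annotations in the hunks that follow. As a minimal sketch of where these types arise in the torch.export flow (my illustration, not part of the commit; TinyModel and its input shape are hypothetical, assuming a recent PyTorch):

import torch
from torch.export import export, ExportedProgram
from torch.fx import GraphModule

class TinyModel(torch.nn.Module):  # hypothetical stand-in for a real model
    def forward(self, x):
        return torch.relu(x)

example_inputs = (torch.randn(1, 4),)
# torch.export.export traces the model into an ExportedProgram ...
exported: ExportedProgram = export(TinyModel(), example_inputs)
# ... and .module() materializes the traced graph as a torch.fx.GraphModule,
# the type the annotated quantization helpers receive and return.
gm: GraphModule = exported.module()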
@@ -146,13 +148,13 @@ def get_model_and_inputs_from_name(
 
 
 def quantize(
-    model: torch.nn.Module,
+    model: GraphModule,
     model_name: str,
     compile_specs: EthosUCompileSpec | VgfCompileSpec | TosaCompileSpec,
     example_inputs: Tuple[torch.Tensor],
     evaluator_name: str | None,
     evaluator_config: Dict[str, Any] | None,
-) -> torch.nn.Module:
+) -> GraphModule:
     """This is the official recommended flow for quantization in pytorch 2.0
     export"""
     logging.info("Quantizing Model...")
@@ -741,7 +743,12 @@ def save_bpte_program(exec_prog, original_model: torch.nn.Module, output_name: s
     save_bundled_program(exec_prog, method_test_suites, output_name)
 
 
-def quantize_model(args, model: torch.nn.Module, example_inputs, compile_spec):
+def quantize_model(
+    args,
+    model: GraphModule,
+    example_inputs: Tuple[torch.Tensor],
+    compile_spec,
+) -> Tuple[GraphModule, ExportedProgram]:
     model_int8 = quantize(
         model,
         args.model_name,
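The new return annotation, Tuple[GraphModule, ExportedProgram], matches the shape of the standard PT2E quantization flow, which yields a quantized GraphModule and can re-export it. A rough sketch of that shape under the stock torch.ao API (illustrative only; quantize_and_reexport and the quantizer parameter are mine, and the actual logic lives in quantize() above):

from typing import Tuple

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export, ExportedProgram
from torch.fx import GraphModule

def quantize_and_reexport(
    gm: GraphModule,
    example_inputs: Tuple[torch.Tensor, ...],
    quantizer,  # e.g. a backend-specific torch.ao Quantizer
) -> Tuple[GraphModule, ExportedProgram]:
    prepared = prepare_pt2e(gm, quantizer)  # insert observers
    prepared(*example_inputs)               # calibrate on sample data
    gm_int8 = convert_pt2e(prepared)        # fold observers into q/dq ops
    return gm_int8, export(gm_int8, example_inputs)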
@@ -759,7 +766,10 @@ def quantize_model(args, model: torch.nn.Module, example_inputs, compile_spec):
 
 
 def to_edge_TOSA_delegate(
-    exported_program, args, model: torch.nn.Module, example_inputs
+    exported_program: ExportedProgram,
+    args,
+    model: GraphModule,
+    example_inputs: Tuple[torch.Tensor],
 ):
     # As we can target multiple output encodings, one must
     # be specified.
@@ -778,7 +788,6 @@ def to_edge_TOSA_delegate(
         model_int8, exported_program = quantize_model(
             args, model, example_inputs, compile_spec
         )
-        model = model_int8
 
     if isinstance(compile_spec, EthosUCompileSpec):
         partitioner = EthosUPartitioner(compile_spec)
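On the deleted line above: per the commit message this was dead code; the function carries on with model_int8 and, as the next hunk shows, returns model_int8 directly, so the rebinding of model was never read.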
@@ -800,7 +809,12 @@ def to_edge_TOSA_delegate(
     return model_int8, edge
 
 
-def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_inputs):
+def to_edge_no_delegate(
+    exported_program: ExportedProgram,
+    args,
+    model: GraphModule,
+    example_inputs: Tuple[torch.Tensor],
+):
     model_int8 = None
     if args.quantize:
         # As we can target multiple output encodings, one must
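One nuance the new annotations leave open: in Python typing, Tuple[torch.Tensor] describes a 1-tuple containing exactly one tensor, while a variable-length tuple of tensors would be Tuple[torch.Tensor, ...]. If example_inputs can hold several tensors, the narrower spelling understates it.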
