Commit 649f92d

Arm backend: Correct type annotations in aot_arm_compiler (#14627)
- Correct/add type annotation in aot_arm_compiler.py
- Remove one redundant variable assignment (dead code)

Signed-off-by: Martin Lindström <[email protected]>
1 parent: a4ac70d
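
For context, the narrower annotations match how these functions are used in the PT2E export flow: torch.export.export produces an ExportedProgram, and the module recovered from it is a torch.fx.GraphModule, which is what quantize() actually receives and returns. A minimal sketch illustrating this relationship (the toy model here is hypothetical, not from the commit):

    import torch
    from torch.export import export, ExportedProgram
    from torch.fx import GraphModule

    class TinyModel(torch.nn.Module):  # hypothetical toy model
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(x)

    example_inputs = (torch.randn(1, 4),)
    exported: ExportedProgram = export(TinyModel(), example_inputs)

    # The callable recovered from an ExportedProgram is a GraphModule,
    # the type that quantize() now annotates for its model argument.
    graph_module = exported.module()
    assert isinstance(graph_module, GraphModule)

Annotating with GraphModule instead of the looser torch.nn.Module thus documents the real contract without changing behavior.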

File tree

1 file changed: +20, -6 lines

1 file changed

+20
-6
lines changed

examples/arm/aot_arm_compiler.py

Lines changed: 20 additions & 6 deletions
@@ -61,6 +61,8 @@
 
 from executorch.extension.export_util.utils import save_pte_program
 from tabulate import tabulate
+from torch.export import ExportedProgram
+from torch.fx import GraphModule
 from torch.utils.data import DataLoader
 
 # Quantize model if required using the standard export quantizaion flow.
@@ -145,13 +147,13 @@ def get_model_and_inputs_from_name(
 
 
 def quantize(
-    model: torch.nn.Module,
+    model: GraphModule,
     model_name: str,
     compile_specs: EthosUCompileSpec | VgfCompileSpec | TosaCompileSpec,
     example_inputs: Tuple[torch.Tensor],
     evaluator_name: str | None,
     evaluator_config: Dict[str, Any] | None,
-) -> torch.nn.Module:
+) -> GraphModule:
     """This is the official recommended flow for quantization in pytorch 2.0
     export"""
     logging.info("Quantizing Model...")
@@ -601,7 +603,12 @@ def save_bpte_program(exec_prog, original_model: torch.nn.Module, output_name: s
     save_bundled_program(exec_prog, method_test_suites, output_name)
 
 
-def quantize_model(args, model: torch.nn.Module, example_inputs, compile_spec):
+def quantize_model(
+    args,
+    model: GraphModule,
+    example_inputs: Tuple[torch.Tensor],
+    compile_spec,
+) -> Tuple[GraphModule, ExportedProgram]:
     model_int8 = quantize(
         model,
         args.model_name,
@@ -619,7 +626,10 @@ def quantize_model(args, model: torch.nn.Module, example_inputs, compile_spec):
 
 
 def to_edge_TOSA_delegate(
-    exported_program, args, model: torch.nn.Module, example_inputs
+    exported_program: ExportedProgram,
+    args,
+    model: GraphModule,
+    example_inputs: Tuple[torch.Tensor],
 ):
     # As we can target multiple output encodings, one must
     # be specified.
@@ -638,7 +648,6 @@ def to_edge_TOSA_delegate(
     model_int8, exported_program = quantize_model(
         args, model, example_inputs, compile_spec
     )
-    model = model_int8
 
     if isinstance(compile_spec, EthosUCompileSpec):
         partitioner = EthosUPartitioner(compile_spec)
@@ -660,7 +669,12 @@ def to_edge_TOSA_delegate(
     return model_int8, edge
 
 
-def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_inputs):
+def to_edge_no_delegate(
+    exported_program: ExportedProgram,
+    args,
+    model: GraphModule,
+    example_inputs: Tuple[torch.Tensor],
+):
     model_int8 = None
     if args.quantize:
         # As we can target multiple output encodings, one must
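
As a follow-up sketch, the new return annotation lets the tuple unpacking at the to_edge_TOSA_delegate call site be read straight off the signature. The quantize_model below is a hypothetical, self-contained stand-in (not the real function, which runs the PT2E quantization flow via quantize()):

    from typing import Any, Tuple

    import torch
    from torch.export import export, ExportedProgram
    from torch.fx import GraphModule

    def quantize_model(
        args: Any,
        model: GraphModule,
        example_inputs: Tuple[torch.Tensor],
        compile_spec: Any,
    ) -> Tuple[GraphModule, ExportedProgram]:
        # Stand-in body: only re-exports the module so the snippet runs
        # on its own while matching the annotated signature.
        exported = export(model, example_inputs)
        return exported.module(), exported

    class TinyModel(torch.nn.Module):  # hypothetical toy model
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + 1

    example_inputs = (torch.randn(2),)
    model = export(TinyModel(), example_inputs).module()

    # The annotated return type documents this unpacking. The diff also
    # drops the subsequent "model = model_int8" line: "model" is never
    # read again in to_edge_TOSA_delegate, so the assignment was dead code.
    model_int8, exported_program = quantize_model(None, model, example_inputs, None)
    assert isinstance(model_int8, GraphModule)
    assert isinstance(exported_program, ExportedProgram)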
