diff --git a/.mypy.ini b/.mypy.ini index e01392a0dfd..a8df0f41b00 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -24,7 +24,7 @@ files = test, util -mypy_path = executorch +mypy_path = executorch,src [mypy-executorch.backends.*] follow_untyped_imports = True diff --git a/backends/arm/test/passes/test_fuse_duplicate_users_pass.py b/backends/arm/test/passes/test_fuse_duplicate_users_pass.py index a7e80794015..ffe56e72691 100644 --- a/backends/arm/test/passes/test_fuse_duplicate_users_pass.py +++ b/backends/arm/test/passes/test_fuse_duplicate_users_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import Dict, Tuple import torch from executorch.backends.arm._passes import FuseDuplicateUsersPass @@ -13,7 +13,12 @@ input_t = Tuple[torch.Tensor] # Input x -class FuseaAvgPool(torch.nn.Module): +class ModuleWithOps(torch.nn.Module): + ops_before_pass: Dict[str, int] + ops_after_pass: Dict[str, int] + + +class FuseaAvgPool(ModuleWithOps): ops_before_pass = { "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 3, } @@ -27,7 +32,7 @@ def forward(self, x): return self.avg(x) + self.avg(x) + self.avg(x) -class FuseAvgPoolChain(torch.nn.Module): +class FuseAvgPoolChain(ModuleWithOps): ops_before_pass = { "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 6, } @@ -44,14 +49,14 @@ def forward(self, x): return first + second + third -modules = { +modules: Dict[str, ModuleWithOps] = { "fuse_avg_pool": FuseaAvgPool(), "fuse_avg_pool_chain": FuseAvgPoolChain(), } @common.parametrize("module", modules) -def test_fuse_duplicate_ops_FP(module: torch.nn.Module): +def test_fuse_duplicate_ops_FP(module: ModuleWithOps): pipeline = PassPipeline[input_t]( module=module, test_data=(torch.ones(1, 1, 1, 1),),