diff --git a/backends/arm/test/models/test_nn_modules.py b/backends/arm/test/models/test_nn_modules.py
index 0daf035a7f1..158a8a587e2 100644
--- a/backends/arm/test/models/test_nn_modules.py
+++ b/backends/arm/test/models/test_nn_modules.py
@@ -17,6 +17,8 @@
 - Transformer
 """
 
+from typing import Callable
+
 import torch
 from executorch.backends.arm.test.common import parametrize
 from executorch.backends.arm.test.tester.test_pipeline import (
@@ -24,25 +26,82 @@
     TosaPipelineINT,
 )
 
+
+def make_module_wrapper(
+    name: str, module_factory: Callable[[], torch.nn.Module]
+) -> torch.nn.Module:
+    class ModuleWrapper(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self._module = module_factory()
+
+        def forward(self, *args, **kwargs):
+            return self._module(*args, **kwargs)
+
+    ModuleWrapper.__name__ = name
+    ModuleWrapper.__qualname__ = name
+    return ModuleWrapper()
+
+
 example_input = torch.rand(1, 6, 16, 16)
 
 module_tests = [
-    (torch.nn.Embedding(10, 10), (torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]),)),
-    (torch.nn.LeakyReLU(), (example_input,)),
-    (torch.nn.BatchNorm1d(16), (torch.rand(6, 16, 16),)),
-    (torch.nn.AdaptiveAvgPool2d((12, 12)), (example_input,)),
-    (torch.nn.ConvTranspose2d(6, 3, 2), (example_input,)),
-    (torch.nn.GRU(10, 20, 2), (torch.randn(5, 3, 10), torch.randn(2, 3, 20))),
-    (torch.nn.GroupNorm(2, 6), (example_input,)),
-    (torch.nn.InstanceNorm2d(16), (example_input,)),
-    (torch.nn.PReLU(), (example_input,)),
     (
-        torch.nn.Transformer(
-            d_model=64,
-            nhead=1,
-            num_encoder_layers=1,
-            num_decoder_layers=1,
-            dtype=torch.float32,
+        make_module_wrapper(
+            "EmbeddingModule",
+            lambda: torch.nn.Embedding(10, 10),
+        ),
+        (torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]),),
+    ),
+    (
+        make_module_wrapper("LeakyReLUModule", torch.nn.LeakyReLU),
+        (example_input,),
+    ),
+    (
+        make_module_wrapper("BatchNorm1dModule", lambda: torch.nn.BatchNorm1d(16)),
+        (torch.rand(6, 16, 16),),
+    ),
+    (
+        make_module_wrapper(
+            "AdaptiveAvgPool2dModule",
+            lambda: torch.nn.AdaptiveAvgPool2d((12, 12)),
+        ),
+        (example_input,),
+    ),
+    (
+        make_module_wrapper(
+            "ConvTranspose2dModule", lambda: torch.nn.ConvTranspose2d(6, 3, 2)
+        ),
+        (example_input,),
+    ),
+    (
+        make_module_wrapper("GRUModule", lambda: torch.nn.GRU(10, 20, 2)),
+        (torch.randn(5, 3, 10), torch.randn(2, 3, 20)),
+    ),
+    (
+        make_module_wrapper("GroupNormModule", lambda: torch.nn.GroupNorm(2, 6)),
+        (example_input,),
+    ),
+    (
+        make_module_wrapper(
+            "InstanceNorm2dModule", lambda: torch.nn.InstanceNorm2d(16)
+        ),
+        (example_input,),
+    ),
+    (
+        make_module_wrapper("PReLUModule", torch.nn.PReLU),
+        (example_input,),
+    ),
+    (
+        make_module_wrapper(
+            "TransformerModule",
+            lambda: torch.nn.Transformer(
+                d_model=64,
+                nhead=1,
+                num_encoder_layers=1,
+                num_decoder_layers=1,
+                dtype=torch.float32,
+            ),
         ),
         (torch.rand((10, 32, 64)), torch.rand((20, 32, 64))),
     ),
@@ -78,9 +137,9 @@ def test_nn_Modules_FP(test_data):
     "test_data",
     test_parameters,
     xfails={
-        "GRU": "RuntimeError: Node aten_linear_default with op was not decomposed or delegated.",
-        "PReLU": "RuntimeError: mul(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.",
-        "Transformer": "AssertionError: Output 0 does not match reference output.",
+        "GRUModule": "RuntimeError: Node aten_linear_default with op was not decomposed or delegated.",
+        "PReLUModule": "RuntimeError: mul(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.",
+        "TransformerModule": "AssertionError: Output 0 does not match reference output.",
     },
 )
 def test_nn_Modules_INT(test_data):
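
Side note: below is a minimal standalone sketch of the make_module_wrapper pattern this patch introduces, assuming only that torch is installed. The demo instance and prints at the bottom are illustrative and not part of the patch. The point of the wrapper is that each call mints a freshly named class, so test IDs and xfail keys derived from the module's class name read GRUModule, PReLUModule, etc. (as the updated xfails keys suggest) instead of colliding on generic torch.nn names.

from typing import Callable

import torch


def make_module_wrapper(
    name: str, module_factory: Callable[[], torch.nn.Module]
) -> torch.nn.Module:
    class ModuleWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Defer construction to the factory so each wrapper owns a fresh module.
            self._module = module_factory()

        def forward(self, *args, **kwargs):
            return self._module(*args, **kwargs)

    # Rename the class so names derived from type(module).__name__ are unique.
    ModuleWrapper.__name__ = name
    ModuleWrapper.__qualname__ = name
    return ModuleWrapper()


# Illustrative usage (not part of the patch):
prelu = make_module_wrapper("PReLUModule", torch.nn.PReLU)
print(type(prelu).__name__)                   # PReLUModule
print(prelu(torch.rand(1, 6, 16, 16)).shape)  # torch.Size([1, 6, 16, 16])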