Skip to content

Commit ed9f50f

Browse files
SaoirseARMfacebook-github-bot
authored andcommitted
Adding LSTM to example models (#5931)
Summary: Addition of LSTM model to example models. Thanks! Pull Request resolved: #5931 Reviewed By: mergennachin Differential Revision: D64047539 Pulled By: digantdesai fbshipit-source-id: 6e0d879027f2b76c482a92b82e05458814efbfeb
1 parent ec0fcee commit ed9f50f

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

examples/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2024 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -16,6 +17,7 @@
1617
"emformer_predict": ("emformer_rnnt", "EmformerRnntPredictorModel"),
1718
"emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
1819
"llama2": ("llama2", "Llama2Model"),
20+
"lstm": ("lstm", "LSTMModel"),
1921
"mobilebert": ("mobilebert", "MobileBertModelExample"),
2022
"mv2": ("mobilenet_v2", "MV2Model"),
2123
"mv2_untrained": ("mobilenet_v2", "MV2UntrainedModel"),

examples/models/lstm/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2024 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
from .model import LSTMModel
9+
10+
__all__ = [
11+
LSTMModel,
12+
]

examples/models/lstm/model.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2024 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import logging
9+
10+
import torch
11+
12+
from torch.nn.quantizable.modules import rnn
13+
14+
from ..model_base import EagerModelBase
15+
16+
17+
class LSTMModel(EagerModelBase):
18+
def __init__(self):
19+
pass
20+
21+
def get_eager_model(self) -> torch.nn.Module:
22+
logging.info("Loading LSTM model")
23+
lstm = rnn.LSTM(10, 20, 2)
24+
logging.info("Loaded LSTM model")
25+
return lstm
26+
27+
def get_example_inputs(self):
28+
input_tensor = torch.randn(5, 3, 10)
29+
h0 = torch.randn(2, 3, 20)
30+
c0 = torch.randn(2, 3, 20)
31+
return (input_tensor, (h0, c0))

0 commit comments

Comments
 (0)