File tree Expand file tree Collapse file tree 3 files changed +45
-0
lines changed Expand file tree Collapse file tree 3 files changed +45
-0
lines changed Original file line number Diff line number Diff line change 1
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
2
2
# All rights reserved.
3
+ # Copyright 2024 Arm Limited and/or its affiliates.
3
4
#
4
5
# This source code is licensed under the BSD-style license found in the
5
6
# LICENSE file in the root directory of this source tree.
16
17
"emformer_predict" : ("emformer_rnnt" , "EmformerRnntPredictorModel" ),
17
18
"emformer_join" : ("emformer_rnnt" , "EmformerRnntJoinerModel" ),
18
19
"llama2" : ("llama2" , "Llama2Model" ),
20
+ "lstm" : ("lstm" , "LSTMModel" ),
19
21
"mobilebert" : ("mobilebert" , "MobileBertModelExample" ),
20
22
"mv2" : ("mobilenet_v2" , "MV2Model" ),
21
23
"mv2_untrained" : ("mobilenet_v2" , "MV2UntrainedModel" ),
Original file line number Diff line number Diff line change
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ # Copyright 2024 Arm Limited and/or its affiliates.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from .model import LSTMModel
9
+
10
+ __all__ = [
11
+ LSTMModel ,
12
+ ]
Original file line number Diff line number Diff line change
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ # Copyright 2024 Arm Limited and/or its affiliates.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import logging
9
+
10
+ import torch
11
+
12
+ from torch .nn .quantizable .modules import rnn
13
+
14
+ from ..model_base import EagerModelBase
15
+
16
+
17
+ class LSTMModel (EagerModelBase ):
18
+ def __init__ (self ):
19
+ pass
20
+
21
+ def get_eager_model (self ) -> torch .nn .Module :
22
+ logging .info ("Loading LSTM model" )
23
+ lstm = rnn .LSTM (10 , 20 , 2 )
24
+ logging .info ("Loaded LSTM model" )
25
+ return lstm
26
+
27
+ def get_example_inputs (self ):
28
+ input_tensor = torch .randn (5 , 3 , 10 )
29
+ h0 = torch .randn (2 , 3 , 20 )
30
+ c0 = torch .randn (2 , 3 , 20 )
31
+ return (input_tensor , (h0 , c0 ))
You can’t perform that action at this time.
0 commit comments