# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
import torch.nn as nn

from ..model_base import EagerModelBase


class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM for sequence modeling."""

    def __init__(self, input_size=100, hidden_size=128, num_layers=2, num_classes=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Output layer (hidden_size * 2 because of bidirectional)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        # Initialize hidden and cell states.
        # For a bidirectional LSTM their shape is (num_layers * 2, batch, hidden_size).
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size, device=x.device)

        # LSTM forward pass; out has shape (batch, seq_len, hidden_size * 2)
        out, _ = self.lstm(x, (h0, c0))

        # Classify from the last time step's output
        out = self.fc(out[:, -1, :])
        return out
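

# Usage sketch (illustrative addition, not part of the original module):
# exercises BidirectionalLSTM with its default sizes to document the
# expected input/output shapes.
def _demo_bidirectional_lstm():
    model = BidirectionalLSTM(input_size=100, hidden_size=128, num_layers=2, num_classes=10)
    model.eval()
    x = torch.randn(4, 50, 100)  # (batch, seq_len, input_size)
    with torch.no_grad():
        logits = model(x)  # (batch, num_classes) == (4, 10)
    return logits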


class BidirectionalLSTMTextClassifier(nn.Module):
    """Bidirectional LSTM for text classification with an embedding layer."""

    def __init__(self, vocab_size=10000, embedding_dim=128, hidden_size=256, num_classes=2):
        super().__init__()
        self.hidden_size = hidden_size

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_size,
            bidirectional=True,
            batch_first=True,
        )

        # Output layer
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        # Token indices -> embeddings: (batch, seq_len, embedding_dim)
        embedded = self.embedding(x)

        # LSTM output: (batch, seq_len, hidden_size * 2)
        lstm_out, _ = self.lstm(embedded)

        # Global max pooling over the sequence dimension: (batch, hidden_size * 2)
        pooled = torch.max(lstm_out, dim=1)[0]

        # Classification
        output = self.fc(pooled)
        return output
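

# Usage sketch (illustrative addition): the classifier consumes integer token
# indices and produces class logits; shapes follow the defaults above.
def _demo_text_classifier():
    model = BidirectionalLSTMTextClassifier(
        vocab_size=10000, embedding_dim=128, hidden_size=256, num_classes=2
    )
    model.eval()
    tokens = torch.randint(0, 10000, (4, 100))  # (batch, seq_len) of token ids
    with torch.no_grad():
        logits = model(tokens)  # (batch, num_classes) == (4, 2)
    return logits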


class BidirectionalLSTMModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading Bidirectional LSTM model")
        model = BidirectionalLSTM(
            input_size=100,
            hidden_size=128,
            num_layers=2,
            num_classes=10,
        )
        model.eval()
        logging.info("Loaded Bidirectional LSTM model")
        return model

    def get_example_inputs(self):
        # Example input: (batch_size=1, seq_len=50, input_size=100)
        tensor_size = (1, 50, 100)
        return (torch.randn(tensor_size),)
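

# Hedged sketch (assumption: the PyTorch 2.x torch.export API is available;
# this is not part of the original file): shows how a wrapper's
# get_eager_model()/get_example_inputs() pair is typically consumed by a
# tracer downstream.
def _export_sketch():
    wrapper = BidirectionalLSTMModel()
    model = wrapper.get_eager_model()
    example_inputs = wrapper.get_example_inputs()
    exported_program = torch.export.export(model, example_inputs)
    return exported_program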


class BidirectionalLSTMTextModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading Bidirectional LSTM text classifier")
        model = BidirectionalLSTMTextClassifier(
            vocab_size=10000,
            embedding_dim=128,
            hidden_size=256,
            num_classes=2,
        )
        model.eval()
        logging.info("Loaded Bidirectional LSTM text classifier")
        return model

    def get_example_inputs(self):
        # Example input: (batch_size=1, seq_len=100) token indices
        tensor_size = (1, 100)
        return (torch.randint(0, 10000, tensor_size),)
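

if __name__ == "__main__":
    # Smoke test (illustrative addition): run each wrapper's eager model on its
    # own example inputs and print the resulting output shapes.
    for wrapper_cls in (BidirectionalLSTMModel, BidirectionalLSTMTextModel):
        wrapper = wrapper_cls()
        model = wrapper.get_eager_model()
        inputs = wrapper.get_example_inputs()
        with torch.no_grad():
            output = model(*inputs)
        print(f"{wrapper_cls.__name__}: output shape {tuple(output.shape)}")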