1616import collections
1717import tensorflow as tf
1818
19- from .layers .merge_two_last_dims import Merge2LastDims
2019from .layers .subsampling import TimeReduction
2120from .transducer import Transducer , BeamHypothesis
22- from ..utils .utils import get_rnn , get_shape_invariants
21+ from ..utils .utils import get_rnn , get_shape_invariants , merge_two_last_dims
2322
# Lightweight record for one decoding hypothesis. The field meanings are not
# shown in this part of the file; judging by the names they hold the current
# index, the predicted token(s) and the encoder / prediction-network states
# (confirm against the decoding code that builds these).
Hypothesis = collections.namedtuple(
    "Hypothesis",
    ("index", "prediction", "encoder_states", "prediction_states")
)
2827
2928
class Reshape(tf.keras.layers.Layer):
    """Keras layer wrapper around the ``merge_two_last_dims`` helper.

    Wrapping the helper in a layer lets the encoder name it and call it like
    any other sub-layer.
    """

    def call(self, inputs):
        """Return ``merge_two_last_dims(inputs)``.

        NOTE(review): presumably this collapses the two trailing axes of
        ``inputs`` into one -- confirm against ``utils.merge_two_last_dims``.
        """
        return merge_two_last_dims(inputs)
31+
32+
3033class StreamingTransducerBlock (tf .keras .Model ):
def __init__(self,
             reduction_factor: int = 0,
             dmodel: int = 640,
             rnn_type: str = "lstm",
             rnn_units: int = 2048,
             layer_norm: bool = True,
             kernel_regularizer=None,
             bias_regularizer=None,
             **kwargs):
    """Build one encoder block: optional time reduction, RNN, optional LN, projection.

    Args:
        reduction_factor: Factor handed to ``TimeReduction``. A value of 0
            (or any non-positive value) disables the reduction step.
        dmodel: Number of units of the final ``Dense`` projection.
        rnn_type: RNN kind, resolved to a layer class through ``get_rnn``.
        rnn_units: Number of units of the RNN layer.
        layer_norm: Whether to create a ``LayerNormalization`` layer.
        kernel_regularizer: Regularizer passed to the RNN and the projection.
        bias_regularizer: Regularizer passed to the RNN and the projection.
        **kwargs: Forwarded to ``tf.keras.Model.__init__`` (e.g. ``name``).
    """
    super(StreamingTransducerBlock, self).__init__(**kwargs)

    # Sub-layers that can be switched off are stored as None so that the
    # forward methods can test them with `is not None`.
    if reduction_factor > 0:
        self.reduction = TimeReduction(reduction_factor, name=f"{self.name}_reduction")
    else:
        self.reduction = None

    RNN = get_rnn(rnn_type)
    self.rnn = RNN(
        units=rnn_units, return_sequences=True,
        name=f"{self.name}_rnn", return_state=True,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer
    )

    if layer_norm:
        self.ln = tf.keras.layers.LayerNormalization(name=f"{self.name}_ln")
    else:
        self.ln = None

    # The projection is always created; it is no longer optional.
    self.projection = tf.keras.layers.Dense(
        dmodel, name=f"{self.name}_projection",
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer
    )
7068
7169 def call (self , inputs , training = False ):
7270 outputs = inputs
@@ -76,8 +74,7 @@ def call(self, inputs, training=False):
7674 outputs = outputs [0 ]
7775 if self .ln is not None :
7876 outputs = self .ln (outputs , training = training )
79- if self .projection is not None :
80- outputs = self .projection (outputs , training = training )
77+ outputs = self .projection (outputs , training = training )
8178 return outputs
8279
8380 def recognize (self , inputs , states ):
@@ -89,12 +86,11 @@ def recognize(self, inputs, states):
8986 outputs = outputs [0 ]
9087 if self .ln is not None :
9188 outputs = self .ln (outputs , training = False )
92- if self .projection is not None :
93- outputs = self .projection (outputs , training = False )
89+ outputs = self .projection (outputs , training = False )
9490 return outputs , new_states
9591
9692 def get_config (self ):
97- conf = super ( StreamingTransducerBlock , self ). get_config ()
93+ conf = {}
9894 if self .reduction is not None :
9995 conf .update (self .reduction .get_config ())
10096 conf .update (self .rnn .get_config ())
@@ -106,38 +102,36 @@ def get_config(self):
106102
107103class StreamingTransducerEncoder (tf .keras .Model ):
def __init__(self,
             reductions: dict = {0: 3, 1: 2},
             dmodel: int = 640,
             nlayers: int = 8,
             rnn_type: str = "lstm",
             rnn_units: int = 2048,
             layer_norm: bool = True,
             kernel_regularizer=None,
             bias_regularizer=None,
             **kwargs):
    """Build the encoder: a reshape layer followed by ``nlayers`` blocks.

    Args:
        reductions: Mapping ``block index -> time reduction factor``. Blocks
            whose index is absent get factor 0, i.e. no reduction. The
            default dict is only read here, never mutated.
        dmodel: Projection size passed to every block.
        nlayers: Number of ``StreamingTransducerBlock`` instances to create.
        rnn_type: RNN kind passed to every block.
        rnn_units: RNN size passed to every block.
        layer_norm: Whether every block applies layer normalization.
        kernel_regularizer: Regularizer passed to every block.
        bias_regularizer: Regularizer passed to every block.
        **kwargs: Forwarded to ``tf.keras.Model.__init__`` (e.g. ``name``).
    """
    super(StreamingTransducerEncoder, self).__init__(**kwargs)

    self.reshape = Reshape(name=f"{self.name}_reshape")

    self.blocks = [
        StreamingTransducerBlock(
            reduction_factor=reductions.get(i, 0),  # key is index, value is the factor
            dmodel=dmodel,
            rnn_type=rnn_type,
            rnn_units=rnn_units,
            layer_norm=layer_norm,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            name=f"{self.name}_block_{i}"
        ) for i in range(nlayers)
    ]

    # Overall time reduction is the product of the positive per-block
    # factors; keys outside range(nlayers) are ignored, matching the blocks
    # that were actually created above.
    self.time_reduction_factor = 1
    for i in range(nlayers):
        reduction_factor = reductions.get(i, 0)
        if reduction_factor > 0:
            self.time_reduction_factor *= reduction_factor
141135
142136 def get_initial_state (self ):
143137 """Get zeros states
@@ -157,7 +151,7 @@ def get_initial_state(self):
157151 return tf .stack (states , axis = 0 )
158152
def call(self, inputs, training=False):
    """Run ``inputs`` through the reshape layer and then every block in order.

    Args:
        inputs: Encoder input features; passed to ``self.reshape`` first.
        training: Forwarded to each block as the ``training`` keyword.

    Returns:
        The output of the last block (or of the reshape layer when there
        are no blocks).
    """
    outputs = self.reshape(inputs)
    for block in self.blocks:
        outputs = block(outputs, training=training)
    return outputs
@@ -173,60 +167,60 @@ def recognize(self, inputs, states):
173167 tf.Tensor: outputs with shape [1, T, E]
174168 tf.Tensor: new states with shape [num_lstms, 1 or 2, 1, P]
175169 """
176- outputs = self .merge (inputs )
170+ outputs = self .reshape (inputs )
177171 new_states = []
178172 for i , block in enumerate (self .blocks ):
179173 outputs , block_states = block .recognize (outputs , states = tf .unstack (states [i ], axis = 0 ))
180174 new_states .append (block_states )
181175 return outputs , tf .stack (new_states , axis = 0 )
182176
def get_config(self):
    """Return one flat config dict for the reshape layer and all blocks.

    Starts from the reshape layer's config and merges each block's config
    in order with ``dict.update``, so when two configs share a key the
    value from the later block wins.
    """
    conf = self.reshape.get_config()
    for block in self.blocks:
        conf.update(block.get_config())
    return conf
187181
188182
189183class StreamingTransducer (Transducer ):
190184 def __init__ (self ,
191185 vocabulary_size : int ,
192- reduction_factor : int = 2 ,
193- reduction_positions : list = [1 ],
194- encoder_dim : int = 320 ,
195- encoder_layers : int = 8 ,
196- encoder_type : str = "lstm" ,
197- encoder_units : int = 1024 ,
186+ encoder_reductions : dict = {0 : 3 , 1 : 2 },
187+ encoder_dmodel : int = 640 ,
188+ encoder_nlayers : int = 8 ,
189+ encoder_rnn_type : str = "lstm" ,
190+ encoder_rnn_units : int = 2048 ,
198191 encoder_layer_norm : bool = True ,
199- embed_dim : int = 320 ,
200- embed_dropout : float = 0 ,
201- num_rnns : int = 2 ,
202- rnn_units : int = 1024 ,
203- rnn_type : str = "lstm" ,
204- layer_norm : bool = True ,
205- joint_dim : int = 320 ,
192+ prediction_embed_dim : int = 320 ,
193+ prediction_embed_dropout : float = 0 ,
194+ prediction_num_rnns : int = 2 ,
195+ prediction_rnn_units : int = 2048 ,
196+ prediction_rnn_type : str = "lstm" ,
197+ prediction_layer_norm : bool = True ,
198+ prediction_projection_units : int = 640 ,
199+ joint_dim : int = 640 ,
206200 kernel_regularizer = None ,
207201 bias_regularizer = None ,
208202 name = "StreamingTransducer" ,
209203 ** kwargs ):
210204 super (StreamingTransducer , self ).__init__ (
211205 encoder = StreamingTransducerEncoder (
212- reduction_factor = reduction_factor ,
213- reduction_positions = reduction_positions ,
214- encoder_dim = encoder_dim ,
215- encoder_layers = encoder_layers ,
216- encoder_type = encoder_type ,
217- encoder_units = encoder_units ,
218- encoder_layer_norm = encoder_layer_norm ,
206+ reductions = encoder_reductions ,
207+ dmodel = encoder_dmodel ,
208+ nlayers = encoder_nlayers ,
209+ rnn_type = encoder_rnn_type ,
210+ rnn_units = encoder_rnn_units ,
211+ layer_norm = encoder_layer_norm ,
219212 kernel_regularizer = kernel_regularizer ,
220213 bias_regularizer = bias_regularizer ,
221214 name = f"{ name } _encoder"
222215 ),
223216 vocabulary_size = vocabulary_size ,
224- embed_dim = embed_dim ,
225- embed_dropout = embed_dropout ,
226- num_rnns = num_rnns ,
227- rnn_units = rnn_units ,
228- rnn_type = rnn_type ,
229- layer_norm = layer_norm ,
217+ embed_dim = prediction_embed_dim ,
218+ embed_dropout = prediction_embed_dropout ,
219+ num_rnns = prediction_num_rnns ,
220+ rnn_units = prediction_rnn_units ,
221+ rnn_type = prediction_rnn_type ,
222+ layer_norm = prediction_layer_norm ,
223+ projection_units = prediction_projection_units ,
230224 joint_dim = joint_dim ,
231225 kernel_regularizer = kernel_regularizer ,
232226 bias_regularizer = bias_regularizer ,
0 commit comments