@@ -378,6 +378,7 @@ def __init__(self,
378378 encoder_depth_multiplier : int = 1 ,
379379 encoder_fc_factor : float = 0.5 ,
380380 encoder_dropout : float = 0 ,
381+ encoder_trainable : bool = True ,
381382 prediction_embed_dim : int = 512 ,
382383 prediction_embed_dropout : int = 0 ,
383384 prediction_num_rnns : int = 1 ,
@@ -386,13 +387,15 @@ def __init__(self,
386387 prediction_rnn_implementation : int = 2 ,
387388 prediction_layer_norm : bool = True ,
388389 prediction_projection_units : int = 0 ,
390+ prediction_trainable : bool = True ,
389391 joint_dim : int = 1024 ,
390392 joint_activation : str = "tanh" ,
391393 prejoint_linear : bool = True ,
392394 joint_mode : str = "add" ,
395+ joint_trainable : bool = True ,
393396 kernel_regularizer = L2 ,
394397 bias_regularizer = L2 ,
395- name : str = "conformer_transducer " ,
398+ name : str = "conformer " ,
396399 ** kwargs ):
397400 super (Conformer , self ).__init__ (
398401 encoder = ConformerEncoder (
@@ -408,7 +411,9 @@ def __init__(self,
408411 fc_factor = encoder_fc_factor ,
409412 dropout = encoder_dropout ,
410413 kernel_regularizer = kernel_regularizer ,
411- bias_regularizer = bias_regularizer
414+ bias_regularizer = bias_regularizer ,
415+ trainable = encoder_trainable ,
416+ name = f"{ name } _encoder"
412417 ),
413418 vocabulary_size = vocabulary_size ,
414419 embed_dim = prediction_embed_dim ,
@@ -419,13 +424,16 @@ def __init__(self,
419424 rnn_implementation = prediction_rnn_implementation ,
420425 layer_norm = prediction_layer_norm ,
421426 projection_units = prediction_projection_units ,
427+ prediction_trainable = prediction_trainable ,
422428 joint_dim = joint_dim ,
423429 joint_activation = joint_activation ,
424430 prejoint_linear = prejoint_linear ,
425431 joint_mode = joint_mode ,
432+ joint_trainable = joint_trainable ,
426433 kernel_regularizer = kernel_regularizer ,
427434 bias_regularizer = bias_regularizer ,
428- name = name , ** kwargs
435+ name = name ,
436+ ** kwargs
429437 )
430438 self .dmodel = encoder_dmodel
431439 self .time_reduction_factor = self .encoder .conv_subsampling .time_reduction_factor
0 commit comments