 from typing import Optional, Tuple, List
 from ..utils.registry import Registry
 from ..attention_cell import MultiHeadAttentionCell, gen_self_attn_mask, gen_mem_attn_mask
-from ..layers import PositionalEmbedding, PositionwiseFFN, InitializerType
+from ..layers import PositionalEmbedding, PositionwiseFFN, InitializerType, AdapterModule
 from ..utils.config import CfgNode as CN
 from ..sequence_sampler import BaseStepDecoder
 
@@ -149,7 +149,9 @@ def __init__(self,
                  bias_initializer: Optional[InitializerType] = 'zeros',
                  activation: str = 'relu',
                  dtype='float32',
-                 layout='NT'):
+                 layout='NT',
+                 use_adapter=False,
+                 adapter_config={}):
         """
 
         Parameters
@@ -186,6 +188,8 @@ def __init__(self,
         self._pre_norm = pre_norm
         self._dtype = dtype
         self._layout = layout
+        self._use_adapter = use_adapter
+        self._adapter_config = adapter_config
         assert layout in ['TN', 'NT'], 'Invalid layout received = {}. ' \
                                        'Only "TN" and "NT" are accepted!'.format(layout)
         assert self._units % self._num_heads == 0, 'units must be divisible by the number of heads'
@@ -204,6 +208,9 @@ def __init__(self,
                                        weight_initializer=weight_initializer,
                                        bias_initializer=bias_initializer,
                                        dtype=self._dtype)
+
+        if self._use_adapter:
+            self.adapter_layer_attn = AdapterModule(in_units=units, adapter_config=adapter_config)
         attention_layout = 'NTK' if self._layout == 'NT' else 'TNK'
         self.attention_cell = \
             MultiHeadAttentionCell(
@@ -225,7 +232,9 @@ def __init__(self,
                                    layer_norm_eps=layer_norm_eps,
                                    activation=activation,
                                    pre_norm=pre_norm,
-                                   dtype=self._dtype)
+                                   dtype=self._dtype,
+                                   use_adapter=self._use_adapter,
+                                   adapter_config=self._adapter_config)
 
     @property
     def layout(self) -> str:
@@ -265,6 +274,8 @@ def forward(self, data, attn_mask):
         out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
         out = self.attention_proj(out)
         out = self.dropout_layer(out)
+        if self._use_adapter:
+            out = self.adapter_layer_attn(out)
         out = out + data
         if not self._pre_norm:
             out = self.layer_norm(out)
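
For context, a minimal sketch of what the imported AdapterModule might look like: a bottleneck adapter in the style of Houlsby et al. (2019) that projects the units-dimensional sub-layer output down to a small hidden size, applies a nonlinearity, projects back up, and adds a skip connection. This is an assumption for illustration, not the actual implementation in ..layers; in particular the adapter_config keys ('units', 'activation'), the zero initialization of the up-projection, and the Gluon 2.x style forward method are hypothetical choices.

from mxnet.gluon import nn


class AdapterModule(nn.HybridBlock):
    """Bottleneck adapter sketch: down-project, activate, up-project, residual."""

    def __init__(self, in_units, adapter_config):
        super().__init__()
        bottleneck_units = adapter_config.get('units', 64)     # assumed config key
        act = adapter_config.get('activation', 'relu')         # assumed config key
        self.down_proj = nn.Dense(units=bottleneck_units,
                                  in_units=in_units,
                                  flatten=False)
        self.activation = nn.Activation(act)
        # Zero-initialized up-projection keeps the adapter close to an identity
        # mapping when it is first inserted into a pretrained model.
        self.up_proj = nn.Dense(units=in_units,
                                in_units=bottleneck_units,
                                flatten=False,
                                weight_initializer='zeros')

    def forward(self, data):
        # data has shape (..., in_units); the output keeps the same shape, so the
        # module can sit on the residual branch right after attention_proj/dropout.
        out = self.up_proj(self.activation(self.down_proj(data)))
        return out + data

With such a module, constructing the encoder layer defined here with use_adapter=True and, say, adapter_config={'units': 64, 'activation': 'relu'} would place one adapter after the attention output projection (as in the forward hunk above) and pass the same flags into PositionwiseFFN so a second adapter can wrap the feed-forward sub-layer.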