@@ -23,15 +23,17 @@ class TransformerDecoderLayer(nn.Module):
2323 """
2424
2525 def __init__ (self , d_model , heads , d_ff , dropout ,
26- self_attn_type = "scaled-dot" , max_relative_positions = 0 ):
26+ self_attn_type = "scaled-dot" , max_relative_positions = 0 ,
27+ aan_useffn = False ):
2728 super (TransformerDecoderLayer , self ).__init__ ()
2829
2930 if self_attn_type == "scaled-dot" :
3031 self .self_attn = MultiHeadedAttention (
3132 heads , d_model , dropout = dropout ,
3233 max_relative_positions = max_relative_positions )
3334 elif self_attn_type == "average" :
34- self .self_attn = AverageAttention (d_model , dropout = dropout )
35+ self .self_attn = AverageAttention (d_model , dropout = dropout ,
36+ aan_useffn = aan_useffn )
3537
3638 self .context_attn = MultiHeadedAttention (
3739 heads , d_model , dropout = dropout )
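Note: the new `aan_useffn` flag is only threaded through to `AverageAttention`, where it appears to toggle the feed-forward sub-layer inside the average-attention block. A minimal sketch of the resulting call; the `onmt.modules` import path and the sizes are assumptions, while the keyword arguments and the two-value return mirror the hunks in this diff:

```python
# Sketch only: d_model=512, dropout=0.1 and the import path are assumptions;
# the keyword arguments mirror the constructor call in the hunk above.
import torch
from onmt.modules import AverageAttention

aan = AverageAttention(512, dropout=0.1, aan_useffn=True)  # FFN sub-layer enabled
x = torch.rand(2, 7, 512)                                  # assumed (batch, tgt_len, d_model)
query, attn = aan(x, mask=None)                            # same call pattern as in forward()
```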
@@ -72,7 +74,7 @@ def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
             query, attn = self.self_attn(input_norm, input_norm, input_norm,
                                          mask=dec_mask,
                                          layer_cache=layer_cache,
-                                         type="self")
+                                         attn_type="self")
         elif isinstance(self.self_attn, AverageAttention):
             query, attn = self.self_attn(input_norm, mask=dec_mask,
                                          layer_cache=layer_cache, step=step)
@@ -83,7 +85,7 @@ def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
         mid, attn = self.context_attn(memory_bank, memory_bank, query_norm,
                                       mask=src_pad_mask,
                                       layer_cache=layer_cache,
-                                      type="context")
+                                      attn_type="context")
         output = self.feed_forward(self.drop(mid) + query)
 
         return output, attn
@@ -127,7 +129,7 @@ class TransformerDecoder(DecoderBase):
 
     def __init__(self, num_layers, d_model, heads, d_ff,
                  copy_attn, self_attn_type, dropout, embeddings,
-                 max_relative_positions):
+                 max_relative_positions, aan_useffn):
         super(TransformerDecoder, self).__init__()
 
         self.embeddings = embeddings
@@ -138,7 +140,8 @@ def __init__(self, num_layers, d_model, heads, d_ff,
         self.transformer_layers = nn.ModuleList(
             [TransformerDecoderLayer(d_model, heads, d_ff, dropout,
              self_attn_type=self_attn_type,
-             max_relative_positions=max_relative_positions)
+             max_relative_positions=max_relative_positions,
+             aan_useffn=aan_useffn)
              for i in range(num_layers)])
 
         # previously, there was a GlobalAttention module here for copy
@@ -159,7 +162,8 @@ def from_opt(cls, opt, embeddings):
             opt.self_attn_type,
             opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
             embeddings,
-            opt.max_relative_positions)
+            opt.max_relative_positions,
+            opt.aan_useffn)
 
     def init_state(self, src, memory_bank, enc_hidden):
         """Initialize decoder state."""
@@ -233,7 +237,8 @@ def _init_cache(self, memory_bank):
         for i, layer in enumerate(self.transformer_layers):
             layer_cache = {"memory_keys": None, "memory_values": None}
             if isinstance(layer.self_attn, AverageAttention):
-                layer_cache["prev_g"] = torch.zeros((batch_size, 1, depth))
+                layer_cache["prev_g"] = torch.zeros((batch_size, 1, depth),
+                                                    device=memory_bank.device)
             else:
                 layer_cache["self_keys"] = None
                 layer_cache["self_values"] = None
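The `device=memory_bank.device` addition keeps the `prev_g` cache on the same device as the encoder output, which avoids a CPU/CUDA mismatch when the AAN cache is updated during GPU decoding. A minimal illustration of the failure mode the fix avoids (assumes a CUDA device and a `(src_len, batch, depth)`-shaped memory bank):

```python
import torch

memory_bank = torch.rand(7, 2, 512, device="cuda")  # assumed (src_len, batch, depth) on GPU
prev_g_old = torch.zeros((2, 1, 512))                # old behaviour: tensor lands on CPU
# prev_g_old + memory_bank[0].unsqueeze(1)           # would raise a device-mismatch RuntimeError
prev_g_new = torch.zeros((2, 1, 512),
                         device=memory_bank.device)  # fix: follow memory_bank's device
```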