     chunk,
     concat,
     constant,
-    expand,
     expand_dims,
     expand_dims_like,
     expand_mask,
@@ -95,15 +94,24 @@ def __init__(self, dim, kernel_size=31, groups=16):
         self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
         self.mish = Mish()
 
-    def forward(self, x, mask=None):  # noqa: F722
+    def forward(self, x, mask=None):
         if default_net().plugin_config.remove_input_padding:
             x = unsqueeze(x, 0)
-        x = permute(x, [0, 2, 1])
-        x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
-        out = permute(x, [0, 2, 1])
+        if mask is not None:
+            mask = mask.view(concat([shape(mask, 0), 1, shape(mask, 1)]))  # [B 1 N]
+            mask = expand_dims_like(mask, x)  # [B D N]
+            mask = cast(mask, x.dtype)
+        x = permute(x, [0, 2, 1])  # [B D N]
+
+        if mask is not None:
+            x = self.mish(self.conv1d2(self.mish(self.conv1d1(x * mask) * mask)) * mask)
+        else:
+            x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
+
+        x = permute(x, [0, 2, 1])  # [B N D]
         if default_net().plugin_config.remove_input_padding:
-            out = squeeze(out, 0)
-        return out
+            x = squeeze(x, 0)
+        return x
 
 
 class Attention(Module):
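For reference, here is a minimal eager-mode sketch of the masked convolution path added above, written in plain PyTorch rather than the TensorRT-LLM graph API. All sizes and tensor names are hypothetical; the mask is assumed to be a [B, N] tensor with 1 for valid frames and 0 for padding, and the operation order mirrors the new forward.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: B utterances, N frames, D channels.
B, N, D, kernel_size, groups = 2, 6, 8, 31, 4
conv1d1 = nn.Conv1d(D, D, kernel_size, groups=groups, padding=kernel_size // 2)
conv1d2 = nn.Conv1d(D, D, kernel_size, groups=groups, padding=kernel_size // 2)
mish = nn.Mish()

x = torch.randn(B, N, D)                       # [B, N, D] input features
lengths = torch.tensor([6, 4])                 # valid frames per utterance
mask = (torch.arange(N)[None, :] < lengths[:, None]).to(x.dtype)  # [B, N]

m = mask.unsqueeze(1)                            # [B, 1, N], broadcasts over channels
h = x.permute(0, 2, 1)                           # [B, D, N] layout for Conv1d
h = mish(conv1d2(mish(conv1d1(h * m) * m)) * m)  # re-zero padded frames around each conv
out = h.permute(0, 2, 1)                         # back to [B, N, D]
```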
@@ -185,6 +193,7 @@ def forward(
         rope_cos,
         rope_sin,
         input_lengths,
+        mask=None,
         c=None,  # context c
         scale=1.0,
         rope=None,
@@ -283,6 +292,7 @@ def __call__(
         input_lengths,
         scale=1.0,
         rope=None,
+        mask=None,
     ) -> torch.FloatTensor:
         query = attn.to_q(x)
         key = attn.to_k(x)
@@ -295,20 +305,8 @@ def __call__(
         inner_dim = key.shape[-1]
         norm_factor = math.sqrt(attn.attention_head_size)
         q_scaling = 1.0 / norm_factor
-        mask = None
-        if not default_net().plugin_config.remove_input_padding:
-            N = shape(x, 1)
-            B = shape(x, 0)
-            seq_len_2d = concat([1, N])
-            max_position_embeddings = 4096
-            # create position ids
-            position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
-            tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
-            tmp_position_ids = expand(tmp_position_ids, concat([B, N]))  # BxL
-            tmp_input_lengths = unsqueeze(input_lengths, 1)  # Bx1
-            tmp_input_lengths = expand(tmp_input_lengths, concat([B, N]))  # BxL
-            mask = tmp_position_ids < tmp_input_lengths  # BxL
-            mask = mask.cast("int32")
+        if default_net().plugin_config.remove_input_padding:
+            mask = None
 
         if default_net().plugin_config.bert_attention_plugin:
             qkv = concat([query, key, value], dim=-1)
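With the in-place mask construction removed, the attention relies on the caller to supply a padding mask (and still drops it when remove_input_padding is enabled, since packed sequences carry no padding). Below is a caller-side sketch that mirrors the deleted block, building the [B, N] mask once from input_lengths; the helper name build_padding_mask is hypothetical, and the functional ops are assumed to come from tensorrt_llm.functional as elsewhere in this file.

```python
import numpy as np
from tensorrt_llm.functional import concat, constant, expand, shape, slice, unsqueeze

def build_padding_mask(x, input_lengths, max_position_embeddings=4096):
    # x: [B, N, D] padded batch, input_lengths: [B] valid lengths per sequence.
    B = shape(x, 0)
    N = shape(x, 1)
    position_ids_buffer = constant(
        np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
    position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=concat([1, N]))
    position_ids = expand(position_ids, concat([B, N]))            # [B, N]
    lengths = expand(unsqueeze(input_lengths, 1), concat([B, N]))  # [B, N]
    mask = position_ids < lengths                                  # 1 where the frame is valid
    return mask.cast("int32")
```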
@@ -393,14 +391,15 @@ def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_attn_head=No
         self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)
 
     def forward(
-        self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError
+        self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError, mask=None
     ):  # x: noised input, t: time embedding
         # pre-norm & modulation for attention input
         norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
         # attention
         # norm ----> (2,1226,1024)
-        attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
-
+        attn_output = self.attn(
+            x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale, mask=mask
+        )
         # process attention output for input x
         if default_net().plugin_config.remove_input_padding:
             x = x + gate_msa * attn_output
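Finally, a hypothetical wiring sketch (not part of this diff) showing how the new mask argument could be threaded from the top level through the position embedding and each block; transformer_blocks, conv_pos_embed, and build_padding_mask (sketched above) are illustrative names only, and default_net is assumed to be imported as elsewhere in this file.

```python
def run_blocks(x, t, rope_cos, rope_sin, input_lengths,
               conv_pos_embed, transformer_blocks, scale=1.0):
    mask = None
    if not default_net().plugin_config.remove_input_padding:
        # Padded batches: build the [B, N] validity mask once and reuse it everywhere.
        mask = build_padding_mask(x, input_lengths)
    x = conv_pos_embed(x, mask=mask)                  # masked ConvPositionEmbedding
    for block in transformer_blocks:
        x = block(x, t, rope_cos, rope_sin, input_lengths, scale=scale, mask=mask)
    return x
```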