@@ -211,7 +211,7 @@ def forward(
         # Combine values with the concept embeddings
         x = self.concept_value_transformation_layer(x, concept_values, concept_value_masks)
         age_embeddings = self.age_embedding_layer(ages)
-        time_embeddings = self.age_embedding_layer(dates)
+        time_embeddings = self.time_embedding_layer(dates)
         positional_embeddings = self.positional_embedding_layer(visit_concept_orders)
         x = self.linear_proj(torch.cat([x, time_embeddings, age_embeddings, positional_embeddings], dim=-1))
         x = gelu_new(x)
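
The embedding fix above swaps the copy-pasted `age_embedding_layer` for `time_embedding_layer`, so `dates` are no longer embedded through the age table. For reference, a minimal sketch of the fusion these lines perform; the sizes and shapes are hypothetical, and torch's tanh-approximated GELU stands in for `gelu_new`:

```python
import torch
import torch.nn as nn

hidden_size, embed_size = 128, 16  # hypothetical sizes
linear_proj = nn.Linear(hidden_size + 3 * embed_size, hidden_size)

x = torch.randn(2, 10, hidden_size)      # concept embeddings (after value transformation)
time_e = torch.randn(2, 10, embed_size)  # time_embedding_layer(dates)
age_e = torch.randn(2, 10, embed_size)   # age_embedding_layer(ages)
pos_e = torch.randn(2, 10, embed_size)   # positional_embedding_layer(visit_concept_orders)

# Concatenate the four streams on the feature axis, project back to the
# model's hidden size, then apply the GELU non-linearity.
x = linear_proj(torch.cat([x, time_e, age_e, pos_e], dim=-1))
x = nn.functional.gelu(x, approximate="tanh")  # gelu_new is the tanh-approximate GELU
```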
@@ -298,21 +298,15 @@ def forward(
         # [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         # The flash attention requires the original attention_mask
-        if not getattr(self.config, "_attn_implementation", "eager") == "xformers":
-            if seq_lens is not None:
-                attention_mask = create_block_diagonal_mask(seq_lens)
-            else:
-                attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+        if seq_lens is None:
+            attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
         else:
-            if seq_lens is None:
-                raise RuntimeError(
-                    f"seq_lens cannot be None when {getattr(self.config, '_attn_implementation', 'eager')} is used"
-                )
             if not _is_package_available("xformers"):
+                raise RuntimeError("xformers must be installed when seq_lens is provided")
+            if input_ids.shape[0] > 1:
                 raise RuntimeError(
-                    f"xformers must be installed when {getattr(self.config, '_attn_implementation', 'eager')} is used"
+                    "seq_lens is provided, which indicates sample packing, hence the batch_size must be one."
                 )
-
             seq_lens_list = seq_lens.flatten().to(torch.int).cpu().numpy().tolist()
             attention_mask = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seq_lens_list, device=seq_lens.device)
 
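
With the rewrite, the unpacked path keeps the standard broadcastable mask from `get_extended_attention_mask`, while the packed path (`seq_lens` provided) fails fast when xformers is missing or when more than one packed row is passed. A small self-contained sketch of why the block-diagonal bias implies a batch size of one; this is not the repo's code, and the CUDA device, fp16 dtype, and tensor sizes are assumptions:

```python
import torch
from xformers.ops import fmha, memory_efficient_attention

seq_lens = [3, 5, 2]  # three samples packed into a single row
mask = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seq_lens)

# Dense equivalent of the bias: 0 inside each sample's block, -inf elsewhere,
# so attention never crosses a sample boundary.
print(mask.materialize((sum(seq_lens), sum(seq_lens))))

# xformers expects [batch, total_tokens, num_heads, head_dim]; the block
# structure only lines up along the token axis when batch == 1.
q = k = v = torch.randn(1, sum(seq_lens), 4, 16, device="cuda", dtype=torch.float16)
out = memory_efficient_attention(q, k, v, attn_bias=mask)
```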