Commit 8c6cc4d

updated the logic for constructing attention_mask
1 parent: cc2b08c


src/cehrbert/models/hf_models/hf_cehrbert.py

Lines changed: 6 additions & 12 deletions
@@ -211,7 +211,7 @@ def forward(
         # Combine values with the concept embeddings
         x = self.concept_value_transformation_layer(x, concept_values, concept_value_masks)
         age_embeddings = self.age_embedding_layer(ages)
-        time_embeddings = self.age_embedding_layer(dates)
+        time_embeddings = self.time_embedding_layer(dates)
         positional_embeddings = self.positional_embedding_layer(visit_concept_orders)
         x = self.linear_proj(torch.cat([x, time_embeddings, age_embeddings, positional_embeddings], dim=-1))
         x = gelu_new(x)
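For context, a minimal standalone sketch of the embedding-combination step this hunk fixes: four embedding streams are concatenated and projected back down before the GELU. Layer names mirror the diff, but the vocabulary sizes, dimensions, and the plain gelu stand-in for gelu_new are illustrative assumptions, not CEHR-BERT's actual configuration.

import torch
import torch.nn as nn

n_embd, seq_len = 128, 16                               # illustrative sizes
age_embedding_layer = nn.Embedding(120, n_embd)         # hypothetical vocab sizes
time_embedding_layer = nn.Embedding(4096, n_embd)
positional_embedding_layer = nn.Embedding(512, n_embd)
linear_proj = nn.Linear(4 * n_embd, n_embd)

x = torch.randn(1, seq_len, n_embd)                     # concept embeddings
ages = torch.randint(0, 120, (1, seq_len))
dates = torch.randint(0, 4096, (1, seq_len))
visit_concept_orders = torch.randint(0, 512, (1, seq_len))

# The fix above: dates go through time_embedding_layer rather than
# age_embedding_layer, so age and time get separate lookup tables.
age_embeddings = age_embedding_layer(ages)
time_embeddings = time_embedding_layer(dates)
positional_embeddings = positional_embedding_layer(visit_concept_orders)
x = linear_proj(torch.cat([x, time_embeddings, age_embeddings, positional_embeddings], dim=-1))
x = nn.functional.gelu(x)                               # stand-in for gelu_new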
@@ -298,21 +298,15 @@ def forward(
         # [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         # The flash attention requires the original attention_mask
-        if not getattr(self.config, "_attn_implementation", "eager") == "xformers":
-            if seq_lens is not None:
-                attention_mask = create_block_diagonal_mask(seq_lens)
-            else:
-                attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+        if seq_lens is None:
+            attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
         else:
-            if seq_lens is None:
-                raise RuntimeError(
-                    f"seq_lens cannot be None when {getattr(self.config, '_attn_implementation', 'eager')} is used"
-                )
             if not _is_package_available("xformers"):
+                raise RuntimeError("xformers must be installed when seq_lens is provided")
+            if input_ids.shape[0] > 1:
                 raise RuntimeError(
-                    f"xformers must be installed when {getattr(self.config, '_attn_implementation', 'eager')} is used"
+                    "seq_lens is provided, which indicates sample packing, hence the batch_size must be one."
                 )
-
             seq_lens_list = seq_lens.flatten().to(torch.int).cpu().numpy().tolist()
             attention_mask = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seq_lens_list, device=seq_lens.device)
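A small sketch (illustrative lengths, not code from this repository) of the mask the new xformers branch builds: BlockDiagonalMask.from_seqlens yields an attention bias that only lets tokens attend within their own packed sequence, which is why sample packing requires batch_size == 1 and a flat list of per-sample lengths.

import torch
from xformers.ops import fmha

seq_lens_list = [2, 3, 1]                 # three samples packed into one row
attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seq_lens_list)

# materialize() renders the bias densely: 0 inside each sample's block,
# -inf everywhere else, i.e. a block-diagonal attention pattern.
total = sum(seq_lens_list)
print(attn_bias.materialize((total, total)))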

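For contrast, the non-packed branch defers to Hugging Face's get_extended_attention_mask, whose documented effect on a 2D padding mask is roughly the following plain-torch transformation into a broadcastable additive bias:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])           # 1 = real token, 0 = padding
extended = attention_mask[:, None, None, :].float()     # [batch, 1, 1, seq]
extended = (1.0 - extended) * torch.finfo(torch.float32).min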