Skip to content

Commit 325824d

Browse files
committed
corrected the logic for extract CLS embeddings in sample packing
1 parent 8cd7df4 commit 325824d

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

src/cehrbert/linear_prob/compute_cehrbert_features.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,7 @@ def main():
323323
cehrbert_output.last_hidden_state, batch["attention_mask"]
324324
)
325325
else:
326-
cls_token_index = torch.argmax((cls_token_indices).to(torch.int), dim=-1)
327-
features = cehrbert_output.last_hidden_state[..., cls_token_index, :].squeeze(axis=0)
326+
features = cehrbert_output.last_hidden_state[cls_token_indices, :].squeeze(axis=0)
328327
features = features.cpu().float().detach().numpy()
329328
else:
330329
if cehrbert_args.average_over_sequence:

0 commit comments

Comments
 (0)