Skip to content

Commit 9112b7e

Browse files
committed
fixed a bug when there is only one CLS token in the entire batch in compute_cehrbert_features
1 parent 0cf0f96 commit 9112b7e

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/cehrbert/linear_prob/compute_cehrbert_features.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,11 @@ def main():
276276
cls_token_index = torch.argmax((cls_token_indices).to(torch.int), dim=-1)
277277
features = cehrbert_output.last_hidden_state[..., cls_token_index, :].squeeze(axis=0)
278278
features = features.cpu().float().detach().numpy()
279+
280+
# This might happen sometimes
281+
if len(features) == cehrbert_model.config.hidden_size:
282+
features = [features]
283+
279284
assert len(features) == len(labels), "the number of features must match the number of labels"
280285
# Flatten features or handle them as a list of arrays (one array per row)
281286
features_list = [feature for feature in features]

0 commit comments

Comments
 (0)