Commit f08cac6
added an option to compute the features by averaging the token embeddings over the entire sequence for each sample
1 parent 41275d7 commit f08cac6

2 files changed: +51 −3 lines

src/cehrbert/linear_prob/compute_cehrbert_features.py

Lines changed: 47 additions & 3 deletions
@@ -36,6 +36,32 @@
 LOG = logging.get_logger("transformers")


+def extract_averaged_embeddings_from_packed_sequence(
+    hidden_states: torch.Tensor, attention_mask: torch.Tensor
+) -> torch.Tensor:
+    # Step 1: find the sample boundaries (positions where the attention mask is 0)
+    mask = attention_mask[0]  # remove the batch dimension for easier processing
+    boundary_indices = (mask == 0).nonzero(as_tuple=False).flatten()
+
+    # Add the start and end boundaries manually
+    start_indices = torch.cat([torch.tensor([-1]), boundary_indices])
+    end_indices = torch.cat([boundary_indices, torch.tensor([mask.size(0)])])
+
+    # Step 2: extract the embeddings between boundaries and average them
+    sample_embeddings = []
+    for start, end in zip(start_indices, end_indices):
+        # Select the embeddings strictly between (start, end); skip segments with no valid tokens
+        if end - start > 1:
+            sample = hidden_states[0, start + 1 : end, :]  # slice (start + 1) to (end - 1)
+            avg_embedding = sample.mean(dim=0)  # average over the sequence length
+            sample_embeddings.append(avg_embedding)
+    # Stack the per-sample averages into a single tensor
+    sample_embeddings = torch.stack(sample_embeddings, dim=0)
+    return sample_embeddings
+
+
 def prepare_finetune_dataset(
     data_args: DataTrainingArguments,
     training_args: TrainingArguments,
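For orientation, here is a minimal, hypothetical usage sketch of the new helper on a toy packed batch. The sizes and values are invented for illustration and do not come from the repository; it assumes the function defined above is in scope.

import torch

# Toy packed batch: 6 positions holding two samples (lengths 3 and 2),
# separated by a single padding position where attention_mask == 0.
hidden_states = torch.arange(6 * 4, dtype=torch.float32).reshape(1, 6, 4)
attention_mask = torch.tensor([[1, 1, 1, 0, 1, 1]])

features = extract_averaged_embeddings_from_packed_sequence(hidden_states, attention_mask)
print(features.shape)  # torch.Size([2, 4]) -- one averaged embedding per packed sample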
@@ -286,10 +312,28 @@ def main():

                cls_token_indices = batch["input_ids"] == cehrgpt_tokenizer.cls_token_index
                if cehrbert_args.sample_packing:
-                   features = cehrbert_output.last_hidden_state[cls_token_indices].cpu().float().detach().numpy()
+                   if cehrbert_args.average_over_sequence:
+                       features = extract_averaged_embeddings_from_packed_sequence(
+                           cehrbert_output.last_hidden_state, batch["attention_mask"]
+                       )
+                   else:
+                       features = cehrbert_output.last_hidden_state[cls_token_indices]
+                   features = features.cpu().float().detach().numpy()
                else:
-                   cls_token_index = torch.argmax((cls_token_indices).to(torch.int), dim=-1)
-                   features = cehrbert_output.last_hidden_state[..., cls_token_index, :].cpu().float().detach().numpy()
+                   if cehrbert_args.average_over_sequence:
+                       # Zero out the padded positions, then average across the sequence
+                       features = torch.where(
+                           batch["attention_mask"].unsqueeze(dim=-1).to(torch.bool),
+                           cehrbert_output.last_hidden_state,
+                           0,
+                       )
+                       features = features.mean(dim=1)
+                   else:
+                       cls_token_index = torch.argmax(cls_token_indices.to(torch.int), dim=-1)
+                       features = cehrbert_output.last_hidden_state[..., cls_token_index, :]
+                   features = features.cpu().float().detach().numpy()
                assert len(features) == len(labels), "the number of features must match the number of labels"
                # Flatten features or handle them as a list of arrays (one array per row)
                features_list = [feature for feature in features]
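Note that in the non-packed branch, zero-filling the padded positions and then calling features.mean(dim=1) divides by the full padded sequence length, so samples with more padding end up with smaller-magnitude embeddings. If a mean over valid tokens only is wanted, a masked mean along the following lines could be substituted. This is a sketch against the variables already in scope in main() (batch, cehrbert_output), not part of the commit:

# Sketch only (not in this commit): average over valid tokens instead of the padded length.
mask = batch["attention_mask"].unsqueeze(dim=-1).to(cehrbert_output.last_hidden_state.dtype)
summed = (cehrbert_output.last_hidden_state * mask).sum(dim=1)
valid_counts = mask.sum(dim=1).clamp(min=1)  # avoid division by zero for all-padding rows
features = (summed / valid_counts).cpu().float().detach().numpy()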

src/cehrbert/runners/hf_runner_argument_dataclass.py

Lines changed: 4 additions & 0 deletions
@@ -341,3 +341,7 @@ class CehrBertArguments:
     max_tokens_per_batch: int = dataclasses.field(
         default=16384, metadata={"help": "Maximum number of tokens in each batch"}
     )
+    average_over_sequence: bool = dataclasses.field(
+        default=False,
+        metadata={"help": "Whether or not to average the token embeddings over the entire sequence"},
+    )
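With the new field added to CehrBertArguments, the option can be switched on through the standard HfArgumentParser flow. A minimal sketch follows; the import path is assumed from the repository's src layout, and the actual runner invocation may differ:

from transformers import HfArgumentParser

from cehrbert.runners.hf_runner_argument_dataclass import CehrBertArguments

parser = HfArgumentParser(CehrBertArguments)
# Equivalent to passing "--average_over_sequence true" on the command line
(cehrbert_args,) = parser.parse_args_into_dataclasses(args=["--average_over_sequence", "true"])
print(cehrbert_args.average_over_sequence)  # True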
