|
36 | 36 | LOG = logging.get_logger("transformers") |
37 | 37 |
|
38 | 38 |
|
def extract_averaged_embeddings_from_packed_sequence(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Mean-pool token embeddings per sample in a packed sequence.

    Several samples are packed into one sequence and separated by padding
    positions (``attention_mask == 0``). Each contiguous run of non-padding
    tokens is averaged into a single embedding vector.

    Args:
        hidden_states: Tensor of shape ``(1, seq_len, hidden_size)``.
        attention_mask: Tensor of shape ``(1, seq_len)``; zeros mark the
            boundaries between packed samples.

    Returns:
        Tensor of shape ``(num_samples, hidden_size)`` — one averaged
        embedding per packed sample (empty when no valid tokens exist).
    """
    # Step 1: find boundaries (positions where the mask is 0)
    mask = attention_mask[0]  # remove batch dimension for easier processing
    boundary_indices = (mask == 0).nonzero(as_tuple=False).flatten()

    # Create the sentinel endpoints on the same device as the mask;
    # a bare torch.tensor([...]) lives on CPU and torch.cat would raise
    # a device-mismatch error when the model runs on an accelerator.
    device = mask.device
    start_indices = torch.cat(
        [torch.tensor([-1], device=device), boundary_indices]
    )
    end_indices = torch.cat(
        [boundary_indices, torch.tensor([mask.size(0)], device=device)]
    )

    # Step 2: extract embeddings between boundaries and average them
    sample_embeddings = []
    for start, end in zip(start_indices, end_indices):
        # Skip empty segments (e.g. consecutive padding positions)
        if end - start > 1:
            # Tokens strictly inside (start, end), i.e. start+1 .. end-1
            segment = hidden_states[0, start + 1 : end, :]
            sample_embeddings.append(segment.mean(dim=0))

    # Degenerate case: no valid tokens at all — torch.stack on an empty
    # list would raise, so return an empty (0, hidden_size) tensor.
    if not sample_embeddings:
        return hidden_states.new_zeros((0, hidden_states.size(-1)))
    return torch.stack(sample_embeddings, dim=0)
| 63 | + |
| 64 | + |
39 | 65 | def prepare_finetune_dataset( |
40 | 66 | data_args: DataTrainingArguments, |
41 | 67 | training_args: TrainingArguments, |
@@ -286,10 +312,28 @@ def main(): |
286 | 312 |
|
# Locate every [CLS] token position in the batch.
cls_token_indices = batch["input_ids"] == cehrgpt_tokenizer.cls_token_index
if cehrbert_args.sample_packing:
    if cehrbert_args.average_over_sequence:
        # One mean-pooled embedding per packed sample.
        features = extract_averaged_embeddings_from_packed_sequence(
            cehrbert_output.last_hidden_state, batch["attention_mask"]
        )
    else:
        # One embedding per sample: the hidden state at its [CLS] token.
        features = cehrbert_output.last_hidden_state[cls_token_indices]
else:
    if cehrbert_args.average_over_sequence:
        # Zero out padding positions so they contribute nothing to the sum.
        features = torch.where(
            batch["attention_mask"].unsqueeze(dim=-1).to(torch.bool),
            cehrbert_output.last_hidden_state,
            0,
        )
        # Average across the sequence dimension.
        # NOTE(review): this divides by the full sequence length including
        # padding; a mask-weighted mean (sum / mask.sum) may be intended —
        # confirm with the model's training objective.
        features = features.mean(dim=1)
    else:
        # First [CLS] position per row (argmax over the boolean mask).
        cls_token_index = torch.argmax(cls_token_indices.to(torch.int), dim=-1)
        # Pair each row with ITS OWN cls position. The previous
        # `last_hidden_state[..., cls_token_index, :]` broadcast the index
        # across all rows and produced a (B, B, H) tensor instead of (B, H).
        batch_indices = torch.arange(
            cls_token_index.size(0), device=cls_token_index.device
        )
        features = cehrbert_output.last_hidden_state[batch_indices, cls_token_index]
# Convert to numpy exactly once, after every branch has produced a tensor.
# (Previously one branch converted inside the branch and then again after
# it — numpy arrays have no .cpu(), so that path crashed.)
features = features.detach().cpu().float().numpy()
assert len(features) == len(labels), "the number of features must match the number of labels"
# Flatten features or handle them as a list of arrays (one array per row)
features_list = [feature for feature in features]
|
0 commit comments