Skip to content

Commit bda8fe9

Browse files
committed
do not use sample packing for running predictions
1 parent b4366cd commit bda8fe9

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

src/cehrbert/data_generators/hf_data_generator/hf_dataset_collator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def __call__(self, examples):
9595
)
9696
# The attention_mask is set to 1 to enable attention for the CLS token
9797
batch["attention_mask"] = torch.cat([torch.full((batch_size, 1), 1.0), batch["attention_mask"]], dim=1)
98+
assert (
99+
batch["attention_mask"].shape[0] == 0
100+
), f"batch['attention_mask'].shape[0] must be 0 in sample packing"
98101
# Set the age of the CLS token to the starting age
99102
batch["ages"] = torch.cat([batch["ages"][:, 0:1], batch["ages"]], dim=1)
100103
# Set the age of the CLS token to the starting date

src/cehrbert/runners/hf_cehrbert_finetune_runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,12 @@ def main():
330330
dataset=processed_dataset["test"],
331331
batch_size=per_device_eval_batch_size,
332332
num_workers=training_args.dataloader_num_workers,
333-
collate_fn=data_collator,
333+
collate_fn=CehrBertDataCollator(
334+
tokenizer=tokenizer,
335+
max_length=config.max_position_embeddings,
336+
is_pretraining=False,
337+
mlm_probability=config.mlm_probability,
338+
),
334339
pin_memory=training_args.dataloader_pin_memory,
335340
)
336341
do_predict(test_dataloader, model_args, training_args)

0 commit comments

Comments
 (0)