Skip to content

Commit b2c58af

Browse files
committed
do not use sample packing for running predicitons
1 parent b4366cd commit b2c58af

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

src/cehrbert/data_generators/hf_data_generator/hf_dataset_collator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ def __call__(self, examples):
116116
)
117117
# Set the visit_segments of the CLS token to a default value 0 because this doesn't belong to a visit
118118
batch["visit_segments"] = torch.cat([torch.full((batch_size, 1), 0), batch["visit_segments"]], dim=1)
119+
else:
120+
assert (
121+
batch["attention_mask"].shape[0] == 1
122+
), f"batch['attention_mask'].shape[0] must be 0 in sample packing"
119123

120124
# This is the most crucial logic for generating the training labels
121125
if self.is_pretraining:

src/cehrbert/runners/hf_cehrbert_finetune_runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,12 @@ def main():
330330
dataset=processed_dataset["test"],
331331
batch_size=per_device_eval_batch_size,
332332
num_workers=training_args.dataloader_num_workers,
333-
collate_fn=data_collator,
333+
collate_fn=CehrBertDataCollator(
334+
tokenizer=tokenizer,
335+
max_length=config.max_position_embeddings,
336+
is_pretraining=False,
337+
mlm_probability=config.mlm_probability,
338+
),
334339
pin_memory=training_args.dataloader_pin_memory,
335340
)
336341
do_predict(test_dataloader, model_args, training_args)

0 commit comments

Comments
 (0)