Skip to content

Commit 671f0f3

Browse files
committed
[ehealth] fix usage of star for python 3.9
1 parent 4ba301b commit 671f0f3

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

examples/biomedical/cblue/train_spo.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,8 @@ def batchify_fn(data):
128128
}): fn(samples)
129129
ent_label = [x['ent_label'] for x in data]
130130
spo_label = [x['spo_label'] for x in data]
131-
# data = input_ids, token_type_ids, position_ids, attention_mask
132-
data = _batchify_fn(data)
133-
batch_size, batch_len = data[0].shape
131+
input_ids, token_type_ids, position_ids, masks = _batchify_fn(data)
132+
batch_size, batch_len = input_ids.shape
134133
num_classes = len(train_ds.label_list)
135134
# Create one-hot labels.
136135
#
@@ -176,7 +175,7 @@ def batchify_fn(data):
176175
# xxx_label are used for metric computation.
177176
ent_label = [one_hot_ent_label, ent_label]
178177
spo_label = [one_hot_spo_label, spo_label]
179-
return (*data), ent_label, spo_label
178+
return input_ids, token_type_ids, position_ids, masks, ent_label, spo_label
180179

181180
train_data_loader = create_dataloader(
182181
train_ds,

0 commit comments

Comments
 (0)