Skip to content

Commit 2e236db

Browse files
authored
Merge pull request #1919 from LemonNoel/fix_star
[ehealth] fix syntax for python 3.9
2 parents 4ba301b + 11881bc commit 2e236db

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

examples/biomedical/cblue/train_spo.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,8 @@ def batchify_fn(data):
128128
}): fn(samples)
129129
ent_label = [x['ent_label'] for x in data]
130130
spo_label = [x['spo_label'] for x in data]
131-
# data = input_ids, token_type_ids, position_ids, attention_mask
132-
data = _batchify_fn(data)
133-
batch_size, batch_len = data[0].shape
131+
input_ids, token_type_ids, position_ids, masks = _batchify_fn(data)
132+
batch_size, batch_len = input_ids.shape
134133
num_classes = len(train_ds.label_list)
135134
# Create one-hot labels.
136135
#
@@ -176,7 +175,7 @@ def batchify_fn(data):
176175
# xxx_label are used for metric computation.
177176
ent_label = [one_hot_ent_label, ent_label]
178177
spo_label = [one_hot_spo_label, spo_label]
179-
return (*data), ent_label, spo_label
178+
return input_ids, token_type_ids, position_ids, masks, ent_label, spo_label
180179

181180
train_data_loader = create_dataloader(
182181
train_ds,

paddlenlp/datasets/cblue.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def _read(self, filename, split):
316316
with open(filename, 'r', encoding='utf-8') as f:
317317
if self.name == 'CMeIE':
318318
for line in f.readlines():
319-
data = json.loads(line, encoding='urf-8')
319+
data = json.loads(line)
320320
labels = self.get_labels()
321321
label_map = dict([(x, i) for i, x in enumerate(labels)])
322322
data_list = data.get('spo_list', [])
@@ -353,7 +353,7 @@ def _read(self, filename, split):
353353

354354
yield data
355355
elif self.name == 'CMeEE':
356-
data_list = json.load(f, encoding='utf-8')
356+
data_list = json.load(f)
357357
for data in data_list:
358358
text_len = len(data[input_keys[0]])
359359
if data.get('entities', None):
@@ -386,7 +386,7 @@ def _read(self, filename, split):
386386
data = dict([(k, v) for k, v in zip(data_keys, data)])
387387
yield data
388388
else:
389-
data_list = json.load(f, encoding='utf-8')
389+
data_list = json.load(f)
390390
for data in data_list:
391391
if data.get('normalized_result', None):
392392
data['labels'] = [

0 commit comments

Comments
 (0)