Skip to content

Commit 509494e

Browse files
authored
make imdb dataset return raw data (#646)
1 parent cd1e30b commit 509494e

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

examples/language_model/bigbird/run_classifier.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import random
2020
from functools import partial
2121
import time
22+
import string
2223

2324
import numpy as np
2425
import paddle
@@ -47,6 +48,7 @@
4748
parser.add_argument("--seed", type=int, default=8, help="Random seed for initialization.")
4849
# yapf: enable
4950
args = parser.parse_args()
51+
TRANSLATOR = str.maketrans('', '', string.punctuation)
5052

5153

5254
def set_seed(args):
@@ -75,7 +77,8 @@ def _tokenize(text):
7577
def _collate_data(data, stack_fn=Stack()):
7678
num_fields = len(data[0])
7779
out = [None] * num_fields
78-
out[0] = stack_fn([_tokenize(x['text']) for x in data])
80+
out[0] = stack_fn(
81+
[_tokenize(x['text'].translate(TRANSLATOR)) for x in data])
7982
out[1] = stack_fn([x['label'] for x in data])
8083
seq_len = len(out[0][0])
8184
# Construct the random attention mask for the random attention

paddlenlp/datasets/imdb.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import json
1717
import io
1818
import os
19-
import string
2019

2120
import numpy as np
2221

@@ -54,8 +53,6 @@ def _get_data(self, mode, **kwargs):
5453
return data_dir
5554

5655
def _read(self, data_dir, *args):
57-
translator = str.maketrans('', '', string.punctuation)
58-
5956
for label in ["pos", "neg"]:
6057
root = os.path.join(data_dir, label)
6158
data_files = os.listdir(root)
@@ -69,7 +66,7 @@ def _read(self, data_dir, *args):
6966
f = os.path.join(root, f)
7067
with io.open(f, 'r', encoding='utf8') as fr:
7168
data = fr.readlines()
72-
data = data[0].translate(translator)
69+
data = data[0]
7370
yield {"text": data, "label": label_id}
7471

7572
def get_labels(self):

0 commit comments

Comments
 (0)