"""
imikolov's simple dataset: http://www.fit.vutbr.cz/~imikolov/rnnlm/
"""
import paddle.v2.dataset.common
import tarfile

__all__ = ['train', 'test']

URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'


def word_count(f, word_freq=None):
    """Accumulate word frequencies from the lines of file object f."""
    add = paddle.v2.dataset.common.dict_add
    if word_freq is None:
        word_freq = {}

    for l in f:
        for w in l.strip().split():
            add(word_freq, w)
        # Count one sentence-begin and one sentence-end marker per line.
        add(word_freq, '<s>')
        add(word_freq, '<e>')

    return word_freq


def build_dict(train_filename, test_filename):
    """Build a word-to-index dictionary from the train and test files
    inside the downloaded simple-examples tarball."""
    with tarfile.open(
            paddle.v2.dataset.common.download(
                paddle.v2.dataset.imikolov.URL, 'imikolov',
                paddle.v2.dataset.imikolov.MD5)) as tf:
        trainf = tf.extractfile(train_filename)
        testf = tf.extractfile(test_filename)
        word_freq = word_count(testf, word_count(trainf))

        # Words appearing no more than TYPO_FREQ times are treated as typos
        # and dropped from the dictionary.
        TYPO_FREQ = 50
        word_freq = filter(lambda x: x[1] > TYPO_FREQ, word_freq.items())

        # Sort by frequency (descending), then alphabetically, and assign
        # consecutive integer ids; '<unk>' takes the last id.
        dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
        words, _ = list(zip(*dictionary))
        word_idx = dict(zip(words, xrange(len(words))))
        word_idx['<unk>'] = len(words)

    return word_idx


# Module-level cache for the word-to-index dictionary; built lazily by
# reader_creator() on first use.
word_idx = {}


def reader_creator(filename, n):
    """Create a reader that yields n-grams of word ids from filename."""
    global word_idx
    if len(word_idx) == 0:
        word_idx = build_dict('./simple-examples/data/ptb.train.txt',
                              './simple-examples/data/ptb.valid.txt')

    def reader():
        with tarfile.open(
                paddle.v2.dataset.common.download(
                    paddle.v2.dataset.imikolov.URL, 'imikolov',
                    paddle.v2.dataset.imikolov.MD5)) as tf:
            f = tf.extractfile(filename)

            UNK = word_idx['<unk>']
            for l in f:
                # Wrap the sentence with begin/end markers, map words to ids
                # ('<unk>' for out-of-vocabulary words), then slide an
                # n-word window over it.
                l = ['<s>'] + l.strip().split() + ['<e>']
                if len(l) >= n:
                    l = [word_idx.get(w, UNK) for w in l]
                    for i in range(n, len(l) + 1):
                        yield tuple(l[i - n:i])

    return reader


def train(n):
    """Reader creator for n-grams from the PTB training set."""
    return reader_creator('./simple-examples/data/ptb.train.txt', n)


def test(n):
    """Reader creator for n-grams from the PTB validation set."""
    return reader_creator('./simple-examples/data/ptb.valid.txt', n)
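

# A minimal usage sketch (illustrative only, not part of the dataset API):
# iterate a few 5-grams from the training reader. Assumes the
# simple-examples tarball can be downloaded from URL above.
if __name__ == '__main__':
    for i, ngram in enumerate(train(5)()):
        print ngram  # a 5-tuple of word ids
        if i >= 4:
            break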