|
5 | 5 |
|
6 | 6 | from model import RnnLm, CrossEntropyLossForLm, UpdateModel
|
7 | 7 | from args import parse_args
|
| 8 | +from reader import create_data_loader |
8 | 9 |
|
9 |
| -from paddlenlp.datasets import load_dataset |
10 | 10 | from paddlenlp.metrics import Perplexity
|
11 |
| -from paddlenlp.data import Vocab |
12 | 11 |
|
13 | 12 | paddle.seed(102)
|
14 | 13 |
|
15 | 14 |
|
16 |
def create_data_loader(batch_size, num_steps, data_path):
    """Create DataLoaders for the PTB language-model task.

    Args:
        batch_size (int): Number of parallel text streams per batch.
        num_steps (int): Length (in tokens) of each input window.
        data_path: Unused here — the PTB corpus is fetched through
            ``paddlenlp.datasets.load_dataset``; kept for caller
            compatibility.

    Returns:
        tuple: ``(train_loader, valid_loader, test_loader, vocab_size)``.
    """
    train_ds, valid_ds, test_ds = load_dataset(
        'ptb', splits=('train', 'valid', 'test'))

    # The vocabulary is built from the training split only.
    tokenized_train = [
        train_ds[idx]['sentence'].split() for idx in range(len(train_ds))
    ]
    vocab = Vocab.build_vocab(tokenized_train, eos_token='</eos>')

    # Because the sentences in PTB dataset might be consecutive, we need to
    # concatenate all texts from our dataset and fold them into chunks while
    # the number of rows is equal to batch size. For example:
    #
    # Sentence1: we're talking about years ago before anyone heard of
    #            asbestos having any questionable properties.
    # Sentence2: there is no asbestos in our products now.
    # Batch_size: 5
    # Grouped_text: [["we're", "talking", "about", "years"],
    #                ["ago", "before", "anyone", "heard"],
    #                ["of", "asbestos", "having", "any"],
    #                ["questionable", "properties", "there", "is"],
    #                ["no", "asbestos", "in", "our"]]
    #
    def group_texts(examples):
        # Flatten the split into one token stream, terminating each
        # sentence with the end-of-sentence token.
        stream = []
        for example in examples:
            stream += example['sentence'].split() + ['</eos>']

        token_ids = vocab.to_indices(stream)

        # Fold the flat id stream into `batch_size` rows, discarding the
        # tail that does not fill a full column.
        row_len = len(token_ids) // batch_size
        folded = np.asarray(
            token_ids[0:batch_size * row_len], dtype='int64').reshape(
                (batch_size, row_len))

        # Cut each row into (input, label) windows of `num_steps` tokens;
        # the label window is the input shifted right by one token. NumPy
        # clips the final label slice, so the last label window can be one
        # column shorter than its input — identical to the original code.
        chunks = []
        for win in range(row_len // num_steps):
            start = win * num_steps
            inputs = np.copy(folded[:, start:start + num_steps])
            labels = np.copy(folded[:, start + 1:start + num_steps + 1])
            chunks.append((inputs, labels))

        return chunks

    train_ds.map(group_texts, batched=True)
    valid_ds.map(group_texts, batched=True)
    test_ds.map(group_texts, batched=True)

    # batch_size=None: every dataset item is already a fully formed batch.
    train_loader = paddle.io.DataLoader(
        train_ds, return_list=True, batch_size=None)
    valid_loader = paddle.io.DataLoader(
        valid_ds, return_list=True, batch_size=None)
    test_loader = paddle.io.DataLoader(
        test_ds, return_list=True, batch_size=None)
    return train_loader, valid_loader, test_loader, len(vocab)
72 |
| - |
73 | 15 | def train(args):
|
74 | 16 | paddle.set_device(args.device)
|
75 | 17 | data_path = args.data_path
|
|
0 commit comments